Visualising State Space Representations of LSTM Networks

Emmanuel M. Smith1, Jim Smith1, Phil Legg1, Simon Francis2

1 Department of Computer Science and Creative Technologies
University of the West of England, Bristol, United Kingdom

2 Montvieux Limited, Gloucestershire, United Kingdom

Long Short-Term Memory (LSTM) networks have proven to be one of the most effective models for making predictions on sequence-based tasks [1, 2]. These models work by capturing, remembering, and forgetting information relevant to their future predictions. The non-linear complexity of the mechanisms involved in this process means we currently lack tools for achieving interpretability [3-5]. Ideally, we want these models to provide an explanation of why they make a particular prediction, given a specific input. Researchers have explored the idea of interpreting LSTMs in specific contexts such as Natural Language Processing (NLP) or classification [6-8], but have put little focus on approaches which generalise across different applications. To address this, we demonstrate a method which enables the interpretation and comparison of LSTM states during time series predictions. We show that by reducing the dimensionality of network states one can visualise patterns and explain model behaviours in a scalable manner.

In this notebook we first train a small LSTM network on a basic square wave prediction task and highlight the complexities which arise whilst attempting to identify patterns in the states of LSTMs. We then train a larger LSTM network on the same problem to demonstrate how this problem is exacerbated by increasing the number of nodes in a network. Next, we discuss the novel concept of using state space representations which allow for the visualisation and comparison of network states in a scalable manner. We first apply this technique to the initial square wave problem to demonstrate the improvements it offers over examining the unmodified states. After this, we apply the same approach to a more complex signal which involves alternating wave types. Finally, we utilise the visualisation technique on a real-world dataset which involves water usage readings from a university campus.

Basic Examples of LSTM Interpretability

The entire purpose of developing new models is to obtain more accurate and powerful tools for predicting signals. Whether those signals are a stream of words, a sensor reading, or a constructed signal, our model must first be able to predict accurately before we can even concern ourselves with the interpretation of those predictions. To demonstrate how one might go about naively interpreting an LSTM model, we first consider a simplified scenario. The scenario involves predicting a standard square wave. Each period contains 16 samples. The model is presented with each sample individually, as input, and it has to predict what the value will be for the next sample in the series.
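
As a rough sketch of the task framing described above, the snippet below builds one-step-ahead input/target pairs from a square wave with 16 samples per period. The helper name `make_next_step_pairs` is our own and is not part of the notebook; the wave is assumed to be half low and half high per period, rescaled to the range [-1, 1].

```python
import numpy


def make_next_step_pairs(values):
    '''Pair each sample with the sample that follows it (hypothetical helper).'''
    values = numpy.asarray(values, dtype=float)
    return values[:-1], values[1:]


# One period: 8 low samples followed by 8 high samples, rescaled to [-1, 1].
period = numpy.concatenate((numpy.zeros(8), numpy.ones(8))) * 2 - 1
wave = numpy.tile(period, 3)

inputs, targets = make_next_step_pairs(wave)

# The target only differs from the input where the wave changes level.
switch_steps = numpy.flatnonzero(inputs != targets)
print(len(inputs), switch_steps)
```

For most timesteps the correct answer is simply to repeat the input; the only difficult predictions are those at the switch points, which is where the model has to rely on what it has remembered.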

In [1]:
# Allow for the collapse of the code cells.
from IPython.display import HTML
HTML('''<script>code_show=true;function code_toggle(){if(code_show){$('div.input').hide();}else{$('div.input').show();}code_show = !code_show;} $(document).ready(code_toggle);</script><form action="javascript:code_toggle()"><input type="submit" value="Click here to toggle on/off the raw code."></form>''')
Out[1]:
In [2]:
import copy
import datetime
import numpy
import pandas
import pickle
import plotly.graph_objs as graphs
import plotly.offline as plotly
import random
from sklearn.decomposition import PCA
from sklearn.preprocessing import normalize
import torch
import torch.autograd as autograd
import torch.nn as nn


class WaveLSTM(nn.Module):
    '''Allow for dynamically sizing the LSTM and extracting its states.'''
    def __init__(self, features=1, layers=1, nodes=2):
        super().__init__()
        self.features = features
        self.nodes = nodes
        self.hidden_layers = nn.ModuleList([nn.LSTM(self.features, self.nodes)])
        for _ in range(layers - 1):
            self.hidden_layers.append(nn.LSTM(self.nodes, self.nodes))
        self.output_layer = nn.Linear(self.nodes, self.features)
        
        self.states = [None for _ in range(layers)]
        self.reset_state()
            

    def reset_state(self):
        for l in range(len(self.hidden_layers)):
            self.states[l] =  (autograd.Variable(torch.zeros(1, 1, self.nodes)),
                               autograd.Variable(torch.zeros(1, 1, self.nodes)))
    
    def forward(self, input_):
        output, self.states[0] = self.hidden_layers[0](input_.view(1, 1, -1), self.states[0])
        for l in range(1, len(self.hidden_layers)):
            output, self.states[l] = self.hidden_layers[l](output, self.states[l])
        return self.output_layer(output.view(1, -1))


def set_random_seeds(seed=0):
    '''Ensure reproducibility.'''
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    
def make_wave(values, tiles=1, offset=None):
    '''Generate repeating wave pattern from given values.'''
    # Randomly generate phase offset if not given.
    if offset is None: 
        length = len(values)
        offset = int(get_random_normal(length//2, length//2, 0, length-1))
    
    # Repeat wave pattern and apply offset.
    values = numpy.tile(values, tiles)
    values = numpy.roll(values, offset, axis=0)
    
    return numpy.arange(len(values)), values


def normalise_wave(wave):
    '''Normalise values for optimal neural network inputs.'''
    values = wave[1]
    if values.ndim == 1:
        values = values.reshape((-1, 1))
        values = normalize(values, norm='max', axis=0)
        values = values.reshape((-1))
    else:
        values = normalize(values, norm='max', axis=0)
    return wave[0], values * 2 - 1


def make_sine(steps=10, tiles=1, offset=0, frequency=1):
    '''Generate a repeating sine wave pattern.'''
    times = numpy.linspace(0, tiles, int(steps*tiles))
    values = numpy.sin(2 * numpy.pi * frequency * times)
    return make_wave(values, 1, offset)


def make_square(steps=10, tiles=1, offset=None):
    '''Generate a repeating square wave pattern.'''
    values = numpy.concatenate((numpy.zeros(steps//2), numpy.ones(steps//2)))
    return make_wave(values, tiles, offset)


def make_saw(steps=10, tiles=1, offset=None):
    '''Generate a repeating saw wave pattern.'''
    values = numpy.arange(steps)
    return make_wave(values, tiles, offset)


def make_triangle(steps=10, tiles=1, offset=None):
    '''Generate a repeating triangle wave pattern.'''
    values = numpy.concatenate((numpy.arange((steps//2) + 1), numpy.arange(steps//2)[:0:-1]))
    return make_wave(values, tiles, offset)


def combine_waves(a, b, tiles=1, offset=None, normalise=True):
    '''Combine two waves to be presented as input features.'''
    # Normalise waves to same scale and range.
    if normalise:
        a, b = normalise_wave(a), normalise_wave(b)
        
    min_length = min(len(a[1]), len(b[1]))
    values = numpy.column_stack((a[1][:min_length], b[1][:min_length]))
    return make_wave(values, tiles=(tiles, 1), offset=offset)


def attach_waves(a, b, tiles=1, offset=None, normalise=True):
    '''Attach one wave to the tail end of another.'''
    # Normalise waves to same scale and range.
    if normalise:
        a, b = normalise_wave(a), normalise_wave(b)
        
    values = numpy.concatenate((a[1], b[1]))
    return make_wave(values, tiles=tiles, offset=offset)
    

def make_traces_2d(trace_data):
    '''Generate data traces for a 2D scatter graph.'''
    traces = []
    for trace in trace_data:
        traces.append(graphs.Scatter(**trace))
    return traces


def make_figure_2d(traces, name, x_axis, y_axis, annotations=[]):
    '''Generate plotable figure of a 2D scatter graph.'''
    layout = graphs.Layout(
        title=name,
        xaxis=graphs.XAxis(title=x_axis, rangeslider=dict()),
        yaxis=graphs.YAxis(title=y_axis),
        annotations=annotations,
    )
    return graphs.Figure(data=traces, layout=layout)


def extract_activations(states, layer=None, state_type='hidden'):
    '''Extract the internal activations from the recorded states of an LSTM model.'''
    if state_type: 
        state_type = {'hidden': 0, 'cell': 1}[state_type]
    
    activations = []
    for state in states:
        # Compare against None so that layer 0 is not mistaken for "no layer given".
        if layer is not None:
            activation = state[layer][state_type].data.numpy().squeeze()
        else:
            activation = state[0][state_type].data.numpy().squeeze()
            activation = numpy.concatenate((activation, state[1][state_type].data.numpy().squeeze()))
        activations.append(activation)
        
    return numpy.array(activations)


def get_random_normal(mean, deviation, minimum=numpy.nan, maximum=numpy.nan):
    '''Generate a bounded random number from a normal distribution.'''
    result = numpy.random.normal(mean, deviation)
    result = min(result, maximum)
    result = max(result, minimum)
    return result



def make_predictions(data, seed_amount, wave_length):
    '''Seed the global `model` with true samples, then feed its own predictions back in.'''
    model.reset_state()
    outputs = []
    states = [copy.copy(model.states)]
    inputs = autograd.Variable(torch.Tensor(data[1][:seed_amount]))
    for input_ in inputs:
        output = model(input_.view(1, 1, -1))
        states.append(copy.copy(model.states))
        outputs.append(output)
    for _ in range(len(data[1]) - seed_amount):
        output = model(output)
        states.append(copy.copy(model.states))
        outputs.append(output)
    return outputs, states

               
def make_prediction_plot(data, outputs, wave_length, seed_amount, title, annotations=[], show_prediction=True):
    traces = make_traces_2d([
        dict(
            name='Presentation (Target)', 
            x=data[0][:wave_length+1],
            y=data[1][:wave_length+1],
            mode='lines'),
        dict(
            name='Seeding (Target)', 
            x=data[0][wave_length:seed_amount+1],
            y=data[1][wave_length:seed_amount+1]),
        dict(
            name = 'Feedback (Target)',
            line = dict(dash='dash'),
            x = data[0][seed_amount:],
            y = data[1][seed_amount:],
        ),
        dict(
            name = 'Prediction (Output)',
            x = data[0][1:],
            y = numpy.array(list(map(lambda x: x.data.numpy(), outputs))).squeeze(),
            visible=True if show_prediction else 'legendonly',
            opacity=0.8,
            line=dict(width=2),
        ),
    ])
    figure = make_figure_2d(traces, title, 'Timestep', 'Value', annotations=annotations)
    return plotly.iplot(figure)


def get_node_activations(states, seed_weighting):
    '''Gather raw activations, their 2D PCA projections, and per-node cell state traces.'''
    activations = {
        'network_hidden': extract_activations(states, state_type='hidden'),
        'network_cell': extract_activations(states, state_type='cell'),
        'layer-0_hidden': extract_activations(states, layer=0, state_type='hidden'),
        'layer-0_cell': extract_activations(states, layer=0, state_type='cell'),
        'layer-1_hidden': extract_activations(states, layer=1, state_type='hidden'),
        'layer-1_cell': extract_activations(states, layer=1, state_type='cell'),
    }

    activations_reduced = {}
    for (name, values) in activations.items():
        reducer = PCA(n_components=2)
        activations_reduced[name] = reducer.fit_transform(values[:int(seed_weighting*len(values))])  # Normalisation affects visualisations negatively.
        activations_reduced[name] = numpy.append(activations_reduced[name], reducer.transform(values[int(seed_weighting*len(values)):]), 0)
    node_activations = []
    for node in range(activations['network_cell'].shape[1]):
        node_activations.append(activations['network_cell'][:, node])
    return activations, activations_reduced, node_activations


def make_state_activation_plot(data, node_activations, title, annotations=[], color=-1):
    node_traces = []
    for i, node_activation in enumerate(node_activations):
        scatter = graphs.Scattergl(
            name = 'Node {}'.format(i),
            x = data[0],
            y = node_activation,
            mode = 'lines',
            line = dict(color='rgba(1, 1, 1, 0.33)' if i != color else 'rgba(1, 1, 1, 0.90)', width=1),
        )
        node_traces.append(scatter)
    layout = graphs.Layout(
        title=title,
        xaxis=graphs.XAxis(title='Timestep'),
        yaxis=graphs.YAxis(title='Activation'),
        hovermode='closest',
        annotations=annotations,
    )
    figure = graphs.Figure(data=node_traces, layout=layout)
    return plotly.iplot(figure)


def make_reduced_state_activation_plot_2D_by_position(data, activations_reduced, wave_length, title, annotations=[]):
    trace_0 = graphs.Scatter(
        name = 'PCA Component 0',
        x = data[0],
        y = activations_reduced['network_cell'][:,0],
        mode = 'lines+markers',
        marker = dict(
            size=4,
            color = numpy.tile(numpy.arange(wave_length), 80),
            colorscale = 'Viridis',
        ),
        line = dict(color='rgba(1, 1, 1, 0.5)', width=1)
    )
    trace_1 = graphs.Scatter(
        name = 'PCA Component 1',
        x = data[0],
        y = activations_reduced['network_cell'][:,1],
        mode = 'lines+markers',
        marker = dict(
            size=4,
            color = numpy.tile(numpy.arange(wave_length), 80),
            colorscale = 'Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Wave Position', ypad=90),
        ),
        line = dict(dash='dot', color='rgba(1, 1, 1, 0.5)', width=1),
    )
    layout = graphs.Layout(
        title=title,
        xaxis=graphs.XAxis(title='Timestep', rangeslider=dict()),
        yaxis=graphs.YAxis(title='Activation'),
        annotations=annotations,
    )
    figure = graphs.Figure(data=[trace_0, trace_1], layout=layout)
    return plotly.iplot(figure)


def make_reduced_state_activation_plot_2D_by_value(data, outputs, activations_reduced, wave_length, title, annotations=[]):
    trace_0 = graphs.Scatter(
        name = 'PCA Component 0',
        x = data[0],
        y = activations_reduced['network_cell'][:,0],
        mode = 'lines+markers',
        marker = dict(
            size=4,
            color = numpy.concatenate((numpy.array(outputs[0][0][0].data.numpy()), numpy.array(list(map(lambda x: x[0][0].data.numpy(), outputs))).squeeze()), axis=0),
            colorscale = 'Viridis',
        ),
        line = dict(color='rgba(1, 1, 1, 0.5)', width=1)
    )
    trace_1 = graphs.Scatter(
        name = 'PCA Component 1',
        x = data[0],
        y = activations_reduced['network_cell'][:,1],
        mode = 'lines+markers',
        marker = dict(
            size=4,
            color = numpy.concatenate((numpy.array(outputs[0][0][0].data.numpy()), numpy.array(list(map(lambda x: x[0][0].data.numpy(), outputs))).squeeze()), axis=0),
            colorscale = 'Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Value', ypad=90),
        ),
        line = dict(dash='dot', color='rgba(1, 1, 1, 0.5)', width=1),
    )
    layout = graphs.Layout(
        title=title,
        xaxis=graphs.XAxis(title='Timestep', rangeslider=dict()),
        yaxis=graphs.YAxis(title='Activation'),
        annotations=annotations,
    )
    figure = graphs.Figure(data=[trace_0, trace_1], layout=layout)
    return plotly.iplot(figure)


def make_state_space_plot_2D_by_position(data, activations_reduced, wave_length, seed_weighting, title, annotations):
    trace_2D = graphs.Scatter(
        name='Presentation',
        x = activations_reduced['network_cell'][:,0][:wave_length],
        y = activations_reduced['network_cell'][:,1][:wave_length],
        text = list(map(lambda t: 'Timestep: ' + str(t), data[0][:wave_length])),
        mode = 'lines+markers',
        marker = dict(
            size='8',
            color = numpy.tile(numpy.arange(wave_length), 80)[:wave_length],
            colorscale='Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Wave Position', ypad=75),
        ),
        line = dict(
            color = 'rgba(1, 1, 1, 0.15)',
            width=1,
        ),
        visible='legendonly'
    )
    trace_2D1 = graphs.Scatter(
        name='Seeding',
        x = activations_reduced['network_cell'][:,0][wave_length:int(seed_weighting*len(data[0]))],
        y = activations_reduced['network_cell'][:,1][wave_length:int(seed_weighting*len(data[0]))],
        text = list(map(lambda t: 'Timestep: ' + str(t), data[0][wave_length:int(seed_weighting*len(data[0]))])),
        mode = 'lines+markers',
        marker = dict(
            size='8',
            color = numpy.tile(numpy.arange(wave_length), 80)[wave_length:int(seed_weighting*len(data[0]))],
            colorscale='Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Wave Position', ypad=75),
        ),
        line = dict(
            color = 'rgba(1, 1, 1, 0.15)',
            width=1,
        ),
    )
    trace_2D2 = graphs.Scatter(
        name = 'Feedback',
        x = activations_reduced['network_cell'][:,0][int(seed_weighting*len(data[0])):],
        y = activations_reduced['network_cell'][:,1][int(seed_weighting*len(data[0])):],
        text = list(map(lambda t: 'Timestep: ' + str(t), data[0][int(seed_weighting*len(data[0])):])),
        mode = 'lines+markers',
        marker = dict(
            size='8',
            color = numpy.tile(numpy.arange(wave_length), 80)[int(seed_weighting*len(data[0])):],
            colorscale='Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Wave Position', ypad=75),
        ),
        line = dict(
            color = 'rgba(1, 1, 1, 0.15)',
            width=1,
        ),
        visible='legendonly'
    )
    layout = graphs.Layout(
        title = title,
        xaxis = dict(title='PCA Component 0'),
        yaxis = dict(title='PCA Component 1'),
        annotations=annotations,
    )
    figure = graphs.Figure(data=[trace_2D, trace_2D1, trace_2D2], layout=layout)
    return plotly.iplot(figure)



def make_state_space_plot_2D_by_value(data, activations_reduced, wave_length, seed_weighting, title, annotations=[]):
    trace_2D = graphs.Scatter(
        name='Presentation',
        x = activations_reduced['network_cell'][:,0][:wave_length],
        y = activations_reduced['network_cell'][:,1][:wave_length],
        text = list(map(lambda t: 'Timestep: ' + str(t), data[0][:wave_length])),
        mode = 'lines+markers',
        marker = dict(
            size='8',
            color = numpy.concatenate((numpy.array(outputs[0][0][0].data.numpy()), numpy.array(list(map(lambda x: x[0][0].data.numpy(), outputs))).squeeze()), axis=0)[:wave_length],
            colorscale='Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Prediction', ypad=75),
        ),
        line = dict(
            color = 'rgba(1, 1, 1, 0.15)',
            width=1,
        ),
        visible='legendonly'
    )
    trace_2D1 = graphs.Scatter(
        name='Seeding',
        x = activations_reduced['network_cell'][:,0][wave_length:int(seed_weighting*len(data[0]))],
        y = activations_reduced['network_cell'][:,1][wave_length:int(seed_weighting*len(data[0]))],
        text = list(map(lambda t: 'Timestep: ' + str(t), data[0][wave_length:int(seed_weighting*len(data[0]))])),
        mode = 'lines+markers',
        marker = dict(
            size='8',
            color = numpy.concatenate((numpy.array(outputs[0][0][0].data.numpy()), numpy.array(list(map(lambda x: x[0][0].data.numpy(), outputs))).squeeze()), axis=0)[wave_length:int(seed_weighting*len(data[0]))],
            colorscale='Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Prediction', ypad=75),
        ),
        line = dict(
            color = 'rgba(1, 1, 1, 0.15)',
            width=1,
        ),
    )
    trace_2D2 = graphs.Scatter(
        name = 'Feedback',
        x = activations_reduced['network_cell'][:,0][int(seed_weighting*len(data[0])):],
        y = activations_reduced['network_cell'][:,1][int(seed_weighting*len(data[0])):],
        text = list(map(lambda t: 'Timestep: ' + str(t), data[0][int(seed_weighting*len(data[0])):])),
        mode = 'lines+markers',
        marker = dict(
            size='8',
            color = numpy.concatenate((numpy.array(outputs[0][0][0].data.numpy()), numpy.array(list(map(lambda x: x[0][0].data.numpy(), outputs))).squeeze()), axis=0)[int(seed_weighting*len(data[0])):],
            colorscale='Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Prediction', ypad=75),
        ),
        line = dict(
            color = 'rgba(1, 1, 1, 0.15)',
            width=1,
        ),
        visible='legendonly'
    )
    layout = graphs.Layout(
        title = title,
        xaxis = dict(title='PCA Component 0'),
        yaxis = dict(title='PCA Component 1'),
        annotations=annotations,
    )
    figure = graphs.Figure(data=[trace_2D, trace_2D1, trace_2D2], layout=layout)
    return plotly.iplot(figure)



def make_state_space_plot_3D_by_position(data, activations_reduced, wave_length, title, annotations=[]):
    trace_2D = graphs.Scatter3d(
        x = activations_reduced['network_cell'][:,0],
        y = activations_reduced['network_cell'][:,1],
        z = data[0],
        mode = 'lines+markers',
        marker = dict(
            size='4',
            color = numpy.tile(numpy.arange(wave_length), 80),
            colorscale='Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Wave Position')
        ),
        line = dict(
            color = 'rgba(1, 1, 1, 0.5)',
        ),
    )
    layout = graphs.Layout(
        title = title,
        height = 1000,
        scene = graphs.Scene(
            xaxis = dict(title='PCA 0'),
            yaxis = dict(title='PCA 1'),
            zaxis = dict(title='Step'),
            annotations = annotations,
        )
    )
    figure = graphs.Figure(data=[trace_2D], layout=layout)
    return plotly.iplot(figure)


def make_water_data(days=2, train=True):
    data = pandas.read_csv('water_data.csv')
    meters = ['Brecon 1']
    meter = random.choice(meters)
    data_meter = data[data['meter']==meter]
    data_meter = data_meter.set_index('date')
    data_meter.index = pandas.to_datetime(data_meter.index)
    data_meter = data_meter.resample('1H').sum()
    data_meter = data_meter.rolling('4h').mean()
    water_data = data_meter['value']['2017-10-28':'2017-11-04']
    water_data = ((water_data - water_data.mean()) / water_data.std()) / 2
    return water_data.index, numpy.array(water_data)


def make_state_space_plot_3D_by_value(data, activations_reduced, wave_length, title, annotations=[]):
    trace_2D = graphs.Scatter3d(
        x = activations_reduced['network_cell'][:,0],
        y = activations_reduced['network_cell'][:,1],
        z = data[0],
        mode = 'lines+markers',
        marker = dict(
            size='4',
            color = numpy.concatenate((numpy.array(outputs[0][0][0].data.numpy()), numpy.array(list(map(lambda x: x[0][0].data.numpy(), outputs))).squeeze()), axis=0),
            colorscale='Viridis',
            showscale=True,
            colorbar=graphs.ColorBar(title='Prediction')
        ),
        line = dict(
            color = 'rgba(1, 1, 1, 0.5)',
        ),
    )
    layout = graphs.Layout(
        title = title,
        height = 1000,
        scene = graphs.Scene(
            xaxis = dict(title='PCA 0'),
            yaxis = dict(title='PCA 1'),
            zaxis = dict(title='Step'),
            annotations = annotations,
        )
    )
    figure = graphs.Figure(data=[trace_2D], layout=layout)
    return plotly.iplot(figure, filename='sine_wave_state_trace')
In [3]:
# Allow for interactive plots.
plotly.init_notebook_mode(connected=True)

# Ensure reproducibility.
set_random_seeds(0)
In [4]:
# Plot a square wave example.
data = normalise_wave(make_square(steps=16, tiles=7, offset=0))
traces = make_traces_2d([
    dict(
        name='Input Signal',
        x=data[0],
        y=data[1],
        mode='lines'),
])
annotations=[
    dict(x=15, y=1.3, showarrow=False, text='Double click to reset graph.'),
    dict(x=35, y=-1.3, showarrow=False, text='Click and drag anywhere to zoom into a selection.'),
]
plotly.iplot(make_figure_2d(traces, 'Square Wave Example', 'Timestep', 'Value', annotations))

Before we describe how we formatted the task, it first helps to understand what an LSTM network does. An LSTM network is similar in structure to a standard neural network. The basic unit of this network is a node which takes in various inputs and produces a predictive output based on that input. Many of these nodes are given the input signal in parallel to form a layer. The outputs of a layer can then be presented to an additional layer to form a more powerful model. This process can be repeated until the last layer, which produces the model’s final prediction. Each LSTM node retains some fundamental information about both the current and historical inputs in the form of two internal values, namely the cell and hidden state activations. When an LSTM network acts upon past information in a prediction, these states are responsible for storing that information. Herein lies the big question of LSTM explainability: how can we determine what information the model is storing when it makes a prediction?
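
To make the two internal values concrete, here is a minimal sketch of the standard LSTM update for a single node with a scalar input. The function `lstm_node_step` and the weights are placeholders chosen purely for illustration; they are not taken from the trained networks used in this notebook.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_node_step(x, hidden, cell, weights):
    '''Advance one scalar LSTM node by one timestep.

    `weights` maps each gate name to (input weight, recurrent weight, bias).
    '''
    def pre_activation(gate):
        w_x, w_h, bias = weights[gate]
        return w_x * x + w_h * hidden + bias

    input_gate = sigmoid(pre_activation('input'))
    forget_gate = sigmoid(pre_activation('forget'))
    output_gate = sigmoid(pre_activation('output'))
    candidate = math.tanh(pre_activation('candidate'))

    # The cell state mixes what was remembered with the new candidate value.
    cell = forget_gate * cell + input_gate * candidate
    # The hidden state is the gated view of the cell state passed onwards.
    hidden = output_gate * math.tanh(cell)
    return hidden, cell


# Placeholder weights (assumed for illustration only).
weights = {
    'input': (0.5, 0.1, 0.0),
    'forget': (0.3, 0.2, 1.0),
    'output': (0.4, 0.1, 0.0),
    'candidate': (1.0, 0.5, 0.0),
}

hidden, cell = 0.0, 0.0  # The state starts at zero, as in reset_state().
for x in [-1.0, -1.0, 1.0, 1.0]:
    hidden, cell = lstm_node_step(x, hidden, cell, weights)
    print('x={:+.0f} hidden={:+.3f} cell={:+.3f}'.format(x, hidden, cell))
```

The pair `(hidden, cell)` carried from one call to the next is exactly what the rest of this notebook records and visualises, once per node and per timestep.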

In order to adequately demonstrate approaches which begin to tackle this question, we set up the square wave prediction scenario to contain 3 distinct phases. Once trained to recognise and reproduce the signal, we present the model with the wave, point by point, with its state initialised to zeroes. Because of this, the first phase involves presenting one period of the wave to the network in order to allow it to generate a meaningful state. This phase is called the presentation phase, as we are presenting the network with the signal for the first time. During the next phase we continually provide the network with the appropriate input for 2 periods to demonstrate how its state would evolve in a real-time prediction scenario. We call the second phase seeding, as we seed the network with the values of the actual square wave, tasking it only with predicting the value of the sample for the next timestep. The final phase exists solely to assess whether or not the network is able to continue to reproduce the input signal without external correction. To achieve this we simply feed the prediction made by the network back into itself as input for the next timestep, rather than giving it the actual value for the wave. Correspondingly, the final phase is called feedback.
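
The three phases can be sketched as a single loop. In the snippet below, `EchoModel` is a hypothetical toy stand-in for a trained network (it simply repeats the sample seen one period earlier) and `run_phases` is our own helper; neither is part of the notebook, and the phase lengths follow the description above rather than any particular cell.

```python
from collections import deque


class EchoModel:
    '''Toy stand-in for a trained network (hypothetical, not the LSTM).

    It predicts that the next sample repeats the one seen a period earlier.
    '''
    def __init__(self, period):
        self.history = deque(maxlen=period)

    def step(self, value):
        self.history.append(value)
        if len(self.history) < self.history.maxlen:
            return 0.0  # No full period seen yet, so there is no context.
        return self.history[0]


def run_phases(signal, model, wave_length, seeding_periods):
    '''Run presentation, seeding, and feedback over a signal.'''
    seed_amount = wave_length * (1 + seeding_periods)
    outputs = []
    # Presentation and seeding: the true signal is used as input.
    for value in signal[:seed_amount]:
        outputs.append(model.step(value))
    # Feedback: the previous prediction is used as input instead.
    for _ in range(len(signal) - seed_amount):
        outputs.append(model.step(outputs[-1]))
    return outputs, seed_amount


wave_length = 16
signal = ([-1.0] * 8 + [1.0] * 8) * 7
outputs, seed_amount = run_phases(signal, EchoModel(wave_length), wave_length, seeding_periods=2)

print('presentation: steps 0 to', wave_length - 1)
print('seeding: steps', wave_length, 'to', seed_amount - 1)
print('feedback: steps', seed_amount, 'to', len(signal) - 1)
```

Here `outputs[t]` is the prediction for `signal[t + 1]`. During presentation the toy model has no context and its outputs are meaningless, which mirrors why the first period is set aside for developing a state.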

Small Network with Square Wave

Since this problem is particularly trivial, we start with a small pre-trained network which consists of 2 layers with 4 LSTM nodes each, and a final linear layer which combines the outputs of the uppermost LSTM layer into a single prediction. The figure below shows the prediction results across the 3 phases outlined above:

In [5]:
# Simple square wave prediction.

# Generate data and load model.
data = normalise_wave(make_square(steps=16, tiles=7, offset=0))
model = WaveLSTM(1, 2, 4)
model.load_state_dict(pickle.load(open('simple_square.p', 'rb')))

# Set parameters relating to data.
wave_length = 16
seed_weighting = 0.6
seed_amount = int(len(data[1]) * seed_weighting)

# Make predictions and record the model states.
outputs, states = make_predictions(data, seed_amount, wave_length)

# Plot results with annotations.
annotations=[
    dict(x=wave_length-2, y=1.05, ax=-10, ay=-40, text='Develop model state.'),
    dict(x=wave_length*3-10, y=-1.05, ax=20, ay=40, text='Continue presenting source signal.'),
    dict(x=wave_length*5+3, y=-1.05, ax=20, ay=40, text="Feedback the model's previous predictions."),
    dict(x=wave_length*7, y=1.1, ax=-120, ay=-30, text="Click legend to show/hide elements."),
]
make_prediction_plot(data, outputs, wave_length, seed_amount, 'Square Wave Prediction Task (Simple)', annotations, show_prediction=False)

By selecting the prediction trace in the legend, we can see that the model has captured and can reproduce the signal; but again, how can we tell what information the network is storing at each timestep? One common first step towards answering this question is to plot the individual cell state activations for each timestep [3], which we do as follows:

In [6]:
# Simple square wave state activations.

_, _, node_activations = get_node_activations(states, seed_weighting)
annotations=[
    dict(x=wave_length*2+5, y=4.2, showarrow=False, text="Observe how node 6 predicts the signal's repetition."),
    dict(x=wave_length-1, y=-1, ax=-40, ay=90, text='Here,'),
    dict(x=wave_length*2-1, y=-1, ax=-40, ay=90, text='and here,'),
    dict(x=wave_length*3-1, y=-1, ax=-40, ay=90, text='and here...'),
    dict(x=wave_length*7, y=4, ax=-120, ay=-30, text="Double-click node to show/hide the rest."),
]
make_state_activation_plot(data, node_activations, 'State Activations for Square Wave Network (Simple)', annotations, color=6)

Looking closely for patterns, one can see that every time the wave switches from high to low, the node activations tend to exhibit the same behaviour. A switch of either kind happens every 8 timesteps, as the period of our wave is 16 samples and the level changes halfway through each period, so a high-to-low switch recurs every 16 timesteps. From this we gain some understanding that our model is in fact capturing useful information in its states. But of course, we already know the dynamics of our wave and what to look for. If we did not know which patterns to expect, these dynamics would be hard to elicit through this simplistic visualisation. In fact, before we explore more complex examples, let us first examine how the difficulty of noticing these patterns changes with a more realistically sized network.
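
As a quick check of where the switches fall, the sketch below locates them directly from the signal. It assumes the same wave layout as the notebook's square wave (8 low samples followed by 8 high samples per period, with no phase offset); the variable names are our own.

```python
import numpy

wave_length = 16
half = wave_length // 2
period = numpy.concatenate((numpy.zeros(half), numpy.ones(half))) * 2 - 1
wave = numpy.tile(period, 4)

# A non-zero difference between neighbouring samples marks a switch.
steps = numpy.diff(wave)
low_to_high = numpy.flatnonzero(steps > 0) + 1  # index of the first high sample
high_to_low = numpy.flatnonzero(steps < 0) + 1  # index of the first low sample

all_switches = numpy.sort(numpy.concatenate((low_to_high, high_to_low)))
print('low to high:', low_to_high)
print('high to low:', high_to_low)
print('gap between consecutive switches:', numpy.diff(all_switches))
```

These indices are where one would expect the activation traces to show their repeated behaviour, which gives a concrete place to look when inspecting the plot.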

Large Network with Square Wave

Since LSTM networks tend to be much larger than 4 nodes per layer, we now consider how the same cell activation visualisation aids in the understanding of a larger network with 32 nodes per layer. Below are the same two plots that were shown for the smaller square wave network, this time for the larger network:

In [7]:
# Simple square wave (large) prediction.

# Generate data and load model.
data = normalise_wave(make_square(steps=16, tiles=7, offset=0))
model = WaveLSTM(1, 2, 32)
model.load_state_dict(pickle.load(open('simple_square_32.p', 'rb')))

# Make predictions and record the model states.
outputs, states = make_predictions(data, seed_amount, wave_length)

# Plot results with annotations.
annotations=[
    dict(x=6, y=-.8, ax=100, ay=40, text='The model has no context yet.'),
    dict(x=wave_length*3-3, y=1.1, ax=-50, ay=-40, text='The larger model better fits the wave pattern.'),
]
make_prediction_plot(data, outputs, wave_length, seed_amount, 'Square Wave Prediction Task (Large)', annotations)
In [8]:
# Simple square wave (large) state activations.

_, activations_reduced, node_activations = get_node_activations(states, seed_weighting)
annotations=[
    dict(x=wave_length*2, y=4.1, showarrow=False, text='Observe how the nodes cross over at the switch points.'),
    dict(x=wave_length*2, y=-.3, ax=-30, ay=135, text='Here,'),
    dict(x=wave_length*3-8, y=-.3, ax=0, ay=135, text='and here,'),
    dict(x=wave_length*3, y=-.3, ax=30, ay=135, text='and here...'),
]
make_state_activation_plot(data, node_activations, 'State Activations for Square Wave Network (Large)', annotations)

From the visualisation of the 64 cell activation states (both layers included) of the larger square wave network, it is already much harder to notice the dynamic patterns that are indicative of an upcoming switch in the signal. Not only that, but the number of nodes in modern LSTM networks can be in the range of thousands. Since it is harder to notice patterns in the states of larger networks, this visualisation does not suffice even if we know the dynamics present in the input signal. To alleviate this, perhaps we could create a more manageable representation of the states that scales well with the number of nodes present.

State Space Representations

When visualising the cell activation states in these networks, we essentially have a problem of dimensionality. The more nodes we have in the network, the more states we need to examine; and the more states we need to examine, the harder it is to notice patterns simply by looking at how the states themselves change over time. A simple solution to this problem is to use a dimensionality reduction technique to minimise the amount of visual clutter, such that patterns captured by the network’s states are easier to identify. That then raises the question: what is the appropriate number of dimensions to maximally reduce visual clutter whilst still eliciting useful visual patterns?

To start, let us consider how useful 2 dimensions would be. If only 2 dimensions were still able to capture interpretable information about the internal states of the network, we could then create a two-dimensional scatter plot. A particular location in this scatter plot would represent a specific network state. If these reduced dimensions do in fact contain useful state information, we would expect points that are closer in space to represent moments when the network is remembering similar information. Conversely, if two points are further apart, the network should be storing vastly different information. How do we verify this whilst still having no method of actually understanding the states at a fundamental level?

Well, if these premises hold true, then this ‘state space representation’ would allow us to know if the network is progressing towards a state that we’ve encountered before. We still do not know what the states contain exactly, but we would then be able to compare our current state with historical ones. If we intuitively know something about our input signal, such as our square wave, then we already know which states should be in similar locations. Since the wave is repeating and doesn’t change across its periods, the same offset position in every period should ideally be storing the same information in anticipation of which part of the wave is coming next. For example, when the square wave is about to switch from low to high, the network needs to keep track of how long it has been since the drop from high to low, and its states should be reflective of that. So for every switch from low to high (or high to low, or any position for that matter) the network should be capturing that same information, which implies that the corresponding positions in our state space should be close together.
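
The validation idea above can be sketched numerically. The following is a minimal illustration rather than the notebook's own code: the helper name phase_distance_summary and the toy circular trajectory are assumptions introduced here, standing in for real reduced network states. It compares the mean distance between points that share a position within the period against the mean distance between points at different positions.

```python
import numpy


def phase_distance_summary(reduced_states, wave_length):
    """Compare distances between states at the same and at different wave positions.

    reduced_states: array of shape (time_steps, 2), one reduced state per step.
    wave_length: number of time steps in one period of the input signal.
    """
    reduced_states = numpy.asarray(reduced_states, dtype=float)
    steps = len(reduced_states)
    # Relative position of every time step within its period.
    positions = numpy.arange(steps) % wave_length
    # Pairwise Euclidean distances between all reduced states.
    deltas = reduced_states[:, None, :] - reduced_states[None, :, :]
    distances = numpy.sqrt((deltas ** 2).sum(axis=-1))
    same_position = positions[:, None] == positions[None, :]
    # Ignore the zero distance of each point to itself.
    off_diagonal = ~numpy.eye(steps, dtype=bool)
    same_mean = distances[same_position & off_diagonal].mean()
    different_mean = distances[~same_position].mean()
    return same_mean, different_mean


# Hypothetical stand-in for reduced network states: a slightly noisy circular
# path that repeats every 16 steps (not real PCA-reduced activations).
rng = numpy.random.default_rng(0)
wave_length = 16
steps = wave_length * 5
angles = 2 * numpy.pi * numpy.arange(steps) / wave_length
toy_states = numpy.stack([numpy.cos(angles), numpy.sin(angles)], axis=1)
toy_states += rng.normal(scale=0.01, size=toy_states.shape)

same_mean, different_mean = phase_distance_summary(toy_states, wave_length)
print(same_mean < different_mean)
```

If the premises hold for a real network, the first value should be clearly smaller than the second, as it is for this toy trajectory.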

Square Wave State Space

Now that we have an idea of what state space representations are, how they could potentially be useful, and how to validate that they are actually representing something useful, let us now apply the concept to our larger square wave model. First we reduce the dimensionality of the cell activations for all of our nodes across the entirety of the network using Principal Component Analysis (PCA). We only train the PCA model on the states during the presentation and seeding phases of our training paradigm, as the feedback phase is representative of future predictions. The following visualisation is the result:

In [9]:
# Simple square wave (Large) PCA state activation plot.

annotations = [
    dict(x=wave_length*2-2, y=10, showarrow=False, text='Observe how we can now differentiate two switch types.'),
    dict(x=wave_length*2+.1, y=2.9, ax=-30, ay=120, text='High to low.'),
    dict(x=wave_length*3-7.2, y=-5, ax=35, ay=40, text='Low to high.'),
]
make_reduced_state_activation_plot_2D_by_position(data, activations_reduced, wave_length, 'PCA State Activations for Square Wave Network (Large)', annotations)
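
The reduction step described above, fitting PCA only on the presentation and seeding phases and then projecting every state, might look roughly like the sketch below. The internals of get_node_activations are not shown in this notebook, so this is an assumed reconstruction: PCA is implemented directly with numpy's SVD, the random cell_states matrix is a placeholder for real activations, and the seed_weighting value of 0.5 is chosen purely for illustration.

```python
import numpy


def fit_pca(states, components=2):
    """Fit PCA with an SVD; returns the mean and the principal axes."""
    mean = states.mean(axis=0)
    # Rows of vt are the principal axes, ordered by explained variance.
    _, _, vt = numpy.linalg.svd(states - mean, full_matrices=False)
    return mean, vt[:components]


def apply_pca(states, mean, axes):
    """Project states onto previously fitted principal axes."""
    return (states - mean) @ axes.T


# Placeholder activations (assumption): 112 time steps by 64 cell states,
# matching 7 periods of 16 steps and 2 layers of 32 nodes.
rng = numpy.random.default_rng(0)
steps, nodes = 112, 64
cell_states = rng.normal(size=(steps, nodes))

# Assumed value for illustration only.
seed_weighting = 0.5
seed_amount = int(steps * seed_weighting)

# Fit on the presentation and seeding phases, then reduce every time step,
# including the feedback phase that the PCA model never saw.
mean, axes = fit_pca(cell_states[:seed_amount])
reduced = apply_pca(cell_states, mean, axes)
print(reduced.shape)  # (112, 2)
```

Fitting on the seeded portion only means the feedback-phase states are projected into a space they did not help define, which keeps the comparison with future predictions honest.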

Since we know the dynamics of the repeating square wave, we can label each point with its relative position within its period. This allows us to see whether or not the same position along the wave has similar values for each of the components. Additionally, we can plot the same two components in our state space scatter plot:

In [10]:
# Simple square wave (Large) 2D state space plot.

annotations = [
    dict(x=3.5, y=3.4, ax=-20, ay=50, text='High to low switch points.'),
    dict(x=-5.1, y=-4, ax=30, ay=-50, text='Low to high switch points.'),
]
make_state_space_plot_2D_by_position(data, activations_reduced, wave_length, seed_weighting, 'PCA of Cell Activations for Square Wave Network (Large) by Wave Position', annotations)

From this state space plot, we can see that our assumptions held true. After the initial presentation phase, the network settled on a consistent travel pattern where similar positions along the wave had similar positions in the plot. Additionally, this holds true even during the feedback phase, with only minor deviations. This suggests that the model contains an internal representation of the dynamics of the wave, and only uses the input for initially syncing with the phase of the wave and correcting minute deviations. If we transform our plot into 3 dimensions, where the third axis is the time step, we can see this more clearly:

In [11]:
# Simple square wave (Large) state space 3D plot.

annotations = [
    dict(x=3.38, y=3.78, z=96, ax=0, ay=-100, text='High to low switch points.'),
    dict(x=-5.05, y=-4.4, z=105, ax=0, ay=-100, text='Low to high switch points.'),
]
make_state_space_plot_3D_by_position(data, activations_reduced, wave_length, 'PCA of Cell Activations for Square Wave Network (Large) by Wave Position', annotations)

Even though this representation is useful for a simple pattern such as a square wave, is it capable of differentiating between multiple patterns in the same input signal?

State Space of Alternating Waves

From the two previous visualisations one can see that the network travels a consistent circular path during each period. Given our assumptions, if you consider the states of a model trained on two alternating waves each with their own dynamics, then our state space plot should contain two distinct traveling paths, one for each wave. To test this we train a network of the same architecture on a more complex repeating signal. The signal starts with 2 cycles of the square wave as before, then transitions to 2 cycles of a triangle wave with only 8 samples per wave instead of the 16 of the square wave. Using the same prediction setup as before we obtain the following results:

In [12]:
# Alternating wave prediction.

# Generate data and load model.
data = attach_waves(make_square(steps=16, tiles=2, offset=0), make_triangle(steps=8, tiles=2, offset=0), tiles=7, offset=0)
model = WaveLSTM(1, 2, 32)
model.load_state_dict(pickle.load(open('square_triangle.p', 'rb')))

# Update parameters to match new wave.
wave_length = 48
seed_amount = int(len(data[1]) * seed_weighting)

# Make predictions and collect the network states.
outputs, states = make_predictions(data, seed_amount, wave_length)

# Plot results.
make_prediction_plot(data, outputs, wave_length, seed_amount, 'Alternating Wave Task')
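
As a check on the parameters used in this cell, the structure of the alternating signal can be sketched independently of the notebook's helpers. The functions square_cycle and triangle_cycle below are hypothetical stand-ins; the real make_square, make_triangle and attach_waves are not shown here, and details such as whether a cycle starts low or high are assumptions. The sketch only illustrates why one full period spans 48 steps.

```python
import numpy


def square_cycle(steps):
    """One square wave cycle: first half low (-1), second half high (+1)."""
    half = steps // 2
    return numpy.concatenate([-numpy.ones(half), numpy.ones(steps - half)])


def triangle_cycle(steps):
    """One triangle wave cycle rising from -1 to +1 and back down."""
    half = steps // 2
    rise = numpy.linspace(-1.0, 1.0, half, endpoint=False)
    fall = numpy.linspace(1.0, -1.0, steps - half, endpoint=False)
    return numpy.concatenate([rise, fall])


# One period of the combined signal: 2 square cycles of 16 samples,
# followed by 2 triangle cycles of 8 samples.
period = numpy.concatenate([
    numpy.tile(square_cycle(16), 2),
    numpy.tile(triangle_cycle(8), 2),
])
# Repeat the combined period 7 times, as in the tiles=7 argument above.
signal = numpy.tile(period, 7)

print(len(period))  # 48
print(len(signal))  # 336
```

The period length of 2 × 16 + 2 × 8 = 48 samples is what wave_length is set to in the cell above.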

From the prediction plot we can see that the model was again able to learn to predict the input signal extremely well. Let us now see how our reduced representation changes when presented with more complex dynamics:

In [13]:
# Alternating Wave PCA state activation plot.

_, activations_reduced, node_activations = get_node_activations(states, seed_weighting)

# Plot results with annotations.
annotations = [
    dict(x=wave_length+8, y=12, ax=-20, ay=-50, text='Square wave patterns.'),
    dict(x=wave_length*2-7, y=1, ax=60, ay=-130, text='Triangle wave patterns.'),
]
make_reduced_state_activation_plot_2D_by_position(data, activations_reduced, wave_length, 'PCA State Activations for Alternating Network by Wave Position', annotations)

Again we can see a clear distinction between the different positions within the square wave. However, we can now also see differences between the square wave cycles and the triangle wave cycles. Below we can see that those differences are reflected in the state space plot as well:

In [14]:
# Alternating wave network 2D state space plot.

annotations = [
    dict(x=-4, y=7.5, ax=250, ay=20, text='First low to high switch points.'),
    dict(x=-3.9, y=6.5, ax=250, ay=50, text='Second low to high switch points.'),
    dict(x=9.95, y=-5.4, ax=-20, ay=-50, text='First high to low switch points.'),
    dict(x=8.95, y=-5.7, ax=-180, ay=-130, text='Second high to low switch points.'),
    dict(x=-3.6, y=-2, ax=150, ay=-10, text='First triangle wave peaks.'),
    dict(x=-3.9, y=-1, ax=40, ay=-50, text='Second triangle wave peaks.'),
]
make_state_space_plot_2D_by_position(data, activations_reduced, wave_length, seed_weighting, 'PCA of Cell State Activations for Alternating Network by Wave Position', annotations)

Transforming the plot into 3 dimensions again makes the differences more evident:

In [15]:
# Alternating wave 3D state space plot.

annotations = [
    dict(x=-3.93, y=6.5, z=311, ax=0, ay=-100, text='Low to high switch points.'),
    dict(x=9.02, y=-5.94, z=319, ax=0, ay=-100, text='High to low switch points.'),
    dict(x=-3.93, y=-1.24, z=332, ax=0, ay=-100, text='Triangle wave peaks.'),
]
make_state_space_plot_3D_by_position(data, activations_reduced, wave_length, 'PCA of Cell Activations for Alternating Network by Wave Position', annotations)

From our state space representation of the alternating wave scenario, we have provided evidence that this method for visualising and comparing the internal states of LSTM networks extends beyond a single wave pattern. That being said, this scenario still remains rather contrived and unrealistic. Additionally, since we control the period of the waves and when they switch, we can aid in the visualisation process by coloring each state space point with its position in the wave. In more realistic data sources, we do not have that information and thus need a different method for verifying the similarity between different nearby states.

Real-World Applications

To assess the utility of this approach in a realistic context, we use a dataset of water usage in the student accommodation buildings of the University of the West of England. The readings were recorded from water meters in four separate buildings, each containing approximately 5-7 water meters. Readings were taken every half an hour over a two-semester period. We first split the dataset into training and test portions by holding out all of the meters from one of the buildings as the test set. The model is then tasked with providing hourly predictions for a single meter signal. During training, a small 2-day window, randomly selected from the available water meters, is presented to the model as a single example.
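
The split and sampling procedure described above could be sketched as follows. This is an illustration under stated assumptions, not the notebook's data pipeline: the readings are synthetic, the 'Example' meters and the sample_window helper are hypothetical, the building is assumed to be the meter name without its trailing number, and the data is assumed to be already resampled to hourly values. Only the column names 'date', 'meter' and 'value' and the meter name 'Brecon 1' come from the notebook.

```python
import numpy
import pandas

# Synthetic hourly readings for four meters (placeholder data, not the real dataset).
rng = numpy.random.default_rng(0)
hours = pandas.date_range('2017-10-01', periods=24 * 14, freq='h')
meters = ['Brecon 1', 'Brecon 2', 'Example 1', 'Example 2']
frames = [
    pandas.DataFrame({'date': hours, 'meter': name, 'value': rng.random(len(hours))})
    for name in meters
]
readings = pandas.concat(frames, ignore_index=True)

# Assumption: the building is the meter name without its trailing number.
readings['building'] = readings['meter'].str.rsplit(' ', n=1).str[0]

# Hold out every meter from one building as the test set.
test_building = 'Brecon'
train = readings[readings['building'] != test_building]
test = readings[readings['building'] == test_building]


def sample_window(frame, rng, days=2, steps_per_day=24):
    """Pick a random meter, then a random contiguous window from its signal."""
    meter_name = rng.choice(frame['meter'].unique())
    values = frame.loc[frame['meter'] == meter_name, 'value'].to_numpy()
    length = days * steps_per_day
    start = rng.integers(0, len(values) - length + 1)
    return meter_name, values[start:start + length]


meter_name, window = sample_window(train, rng)
print(window.shape)  # (48,)
```

Splitting by building rather than by time means the test meters are never seen during training, so good test performance would suggest the model has learned usage patterns that transfer across meter locations.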

The idea is that the model should be able to learn the normal patterns of student water usage whilst ignoring the day-to-day noise and the differences in meter locations. For example, if you calculate the average value by hour for a single meter, you obtain a pattern similar to this:

In [16]:
# Average pattern for one meter.

# Load the water data with an average example.
data = pandas.read_csv('../../data/water_data/water_data.csv')
meter = data[data['meter'] == 'Brecon 1']
meter = meter.set_index('date')
meter.index = pandas.to_datetime(meter.index)
meter = meter.resample('1h').sum()
meter = meter.groupby(meter.index.hour)['value'].mean()

# Plot average data with annotations.
traces = make_traces_2d([dict(
    x=numpy.array(meter.index),
    y=numpy.array(meter),
    mode='lines',
)])
annotations = [
    dict(x=5, y=0.15, ax=-20, ay=-50, text='Early morning dip.'),
    dict(x=12, y=0.5, ax=-20, ay=50, text='Lunchtime peak.'),
    dict(x=20, y=0.45, ax=-20, ay=50, text='Evening peak.'),
]
figure = make_figure_2d(traces, 'Average Water Usage for Single Meter', 'Hour', 'Value', annotations)
plotly.iplot(figure)

As can be expected, there is a significant drop in water usage in the early hours of the morning, as well as two peaks, one around lunchtime and one in the evening. Even though any particular day on a specific meter can deviate drastically from this pattern, there is still an underlying distribution of usage which is common across all of the meter locations. This is some of the information we should expect the model to be capturing. Perhaps there are other dynamics which become more evident in the state space representation, but first let us examine how well our network performs over a 7-day period:

In [17]:
# Water data predictions.

# Generate water data example and load model.
data = make_water_data(7)
model = WaveLSTM(1, 2, 32)
model.load_state_dict(pickle.load(open('water_data3.p', 'rb')))

# Update parameters to match water data.
seed_weighting = 0.85
wave_length = 24
seed_amount = int(len(data[1]) * seed_weighting)

# Make predictions and collect the network states.
outputs, states = make_predictions(data, seed_amount, wave_length)

# Plot results with annotations.
annotations = [
    dict(x=datetime.datetime(2017, 10, 30, 5, 0), y=-0.9, ax=-30, ay=30, text='Early morning dip.'),
    dict(x=datetime.datetime(2017, 10, 30, 11, 30), y=0.9, ax=0, ay=-40, text='Lunchtime peak.'),
    dict(x=datetime.datetime(2017, 10, 30, 23, 0), y=0.6, ax=70, ay=-60, text='Evening peak.'),
]
make_prediction_plot(data, outputs, wave_length, seed_amount, 'Water Data Task', annotations)

Next, we apply the initial method for visualising the model’s states:

In [18]:
# Plot of water data states.

_, activations_reduced, node_activations = get_node_activations(states, seed_weighting)
annotations = [
    dict(x=datetime.datetime(2017, 10, 30, 5, 0), y=-1.2, ax=-30, ay=40, text='Early morning dip.'),
    dict(x=datetime.datetime(2017, 10, 30, 11, 30), y=0.7, ax=0, ay=-40, text='Lunchtime peak.'),
    dict(x=datetime.datetime(2017, 10, 30, 23, 0), y=0.5, ax=70, ay=-60, text='Evening peak.'),
]
make_state_activation_plot(data, node_activations, 'State Activations for Water Data Network', annotations)