QAST: A Dataset of Tensor Programs Execution Times

Skill Level: Intermediate, Advanced
Area of Focus: Artificial Intelligence
Platform/Hardware: Artificial Intelligence

Introduction:

Current deep learning frameworks such as PyTorch or TensorFlow can optimize a computational graph representation. However, they do not optimize hardware-specific, operator-level transformations; instead, they rely on manually tuned, vendor-specific operator libraries.

TVM [1] recently filled this gap: it is a compiler framework that performs both graph-level and operator-level optimization in an end-to-end manner. For a given target hardware, each operator defines a schedule configuration space, and TVM can compile the resulting tensor program and measure its execution time on the target hardware. Searching this space is a hard optimization problem: in some GPU cases, the search space of a single conv2d operator contains more than 10^6 configurations.
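The growth comes from how configuration spaces are built: a schedule template exposes several tuning knobs, and the space is (roughly) their Cartesian product, so its size is the product of the per-knob choice counts. The sketch below illustrates this arithmetic with hypothetical knob names and loop extents, not TVM's actual conv2d template; the helper counts the ordered ways to split a loop extent into factors, which is the kind of choice a loop-split knob enumerates.

```python
from math import prod


def num_ordered_factorizations(n: int, parts: int) -> int:
    """Count ordered tuples of `parts` positive integers whose product is n."""
    if parts == 1:
        return 1
    # Pick the first factor d (a divisor of n), then split the rest recursively.
    return sum(
        num_ordered_factorizations(n // d, parts - 1)
        for d in range(1, n + 1)
        if n % d == 0
    )


# Hypothetical knobs for a conv2d-like template (illustrative only).
knob_choices = {
    "tile_oc": num_ordered_factorizations(64, 2),  # split 64 output channels into 2 levels
    "tile_ow": num_ordered_factorizations(56, 2),  # split an output width of 56 into 2 levels
    "unroll_kw": 2,                                # unroll the innermost loop or not
}

# The space is the Cartesian product of the knobs, so the sizes multiply.
space_size = prod(knob_choices.values())
print(knob_choices, space_size)
```

Each additional knob, or a deeper split of an existing loop, multiplies the total, which is how a single operator's space can exceed 10^6 configurations.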

Recent efforts address this issue by learning to optimize tensor programs from data rather than relying on heuristics. When tensor programs are treated as data, the abstract syntax tree (AST) associated with an operator configuration provides a rich input space. Graph Neural Network (GraphNN) models are a good fit for ASTs because they preserve the graph structure and propagate information among nodes.
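To make "information propagation among nodes" concrete, here is a minimal sketch of one message-passing step in NumPy, assuming plain sum aggregation followed by a ReLU. It is a generic illustration, not the GraphNN architecture of [2]; the edge list uses (source, destination) columns, matching the edges matrix described under Content.

```python
import numpy as np


def propagate(node_features: np.ndarray, edges: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """One round of sum-aggregation message passing.

    node_features: (num_nodes, num_features)
    edges: (num_edges, 2), column 0 = source index, column 1 = destination index
    weight: (num_features, hidden), a linear map applied to every message
    """
    # One message per edge, computed from the source node's features.
    messages = node_features[edges[:, 0]] @ weight
    # Sum incoming messages per destination node (np.add.at handles repeated indices).
    aggregated = np.zeros((node_features.shape[0], weight.shape[1]))
    np.add.at(aggregated, edges[:, 1], messages)
    return np.maximum(aggregated, 0.0)  # ReLU


# Tiny chain graph 0 -> 1 -> 2 with one-hot features and an identity weight.
features = np.eye(3)
edges = np.array([[0, 1], [1, 2]])
out = propagate(features, edges, np.eye(3))
print(out)
```

Stacking several such steps lets information from a node (e.g., an outer loop) reach nodes several hops away in the AST.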

Objective:

Here, we present the QAST dataset, which supports the experiments in our workshop paper at ICLR 2019, Simulating Execution Time of Tensor Programs Using Graph Neural Networks [2]. We hope this dataset will benefit the graph research community and raise interest in optimizing compiler research.

Description

We collect data for the conv2d operators of ResNet18. The network contains twelve unique convolution workloads, i.e., twelve unique parameterizations of its 2D convolution operators. We list them in the following table:

Workload  H    W    Cin  Cout  kernel  stride  padding  dilation  #configs
C1        224  224  3    64    (7, 7)  (2, 2)  (3, 3)   (1, 1)    252
C2        56   56   64   64    (3, 3)  (1, 1)  (1, 1)   (1, 1)    784
C3        56   56   64   64    (1, 1)  (1, 1)  (1, 1)   (1, 1)    784
C4        56   56   64   128   (3, 3)  (2, 2)  (1, 1)   (1, 1)    672
C5        56   56   64   128   (1, 1)  (2, 2)  (1, 1)   (1, 1)    672
C6        28   28   128  128   (3, 3)  (1, 1)  (1, 1)   (1, 1)    768
C7        28   28   128  256   (3, 3)  (2, 2)  (1, 1)   (1, 1)    576
C8        28   28   128  256   (1, 1)  (2, 2)  (1, 1)   (1, 1)    576
C9        14   14   256  256   (3, 3)  (1, 1)  (1, 1)   (1, 1)    648
C10       14   14   256  512   (3, 3)  (2, 2)  (1, 1)   (1, 1)    360
C11       14   14   256  512   (1, 1)  (2, 2)  (1, 1)   (1, 1)    360
C12       7    7    512  512   (3, 3)  (1, 1)  (1, 1)   (1, 1)    400

Here, H and W are the input height and width, and Cin and Cout are the numbers of input and output channels. For example, workload C1 is a convolution over 3 input channels of size 224 x 224 that produces 64 output channels. The last column gives the size of the configuration space (#configs) of each conv2d workload for the x86-64 target. All measurements were performed with TVM v0.5 on an Intel Xeon CPU E5-1620 v4 @ 3.50GHz.
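The output resolution of each workload follows from the table columns via the standard convolution arithmetic. The helper below is an illustration of that formula, not part of the dataset; for C1 it gives the 112 x 112 output of the ResNet18 stem.

```python
def conv2d_output_size(size: int, kernel: int, stride: int, padding: int, dilation: int) -> int:
    """Output length along one spatial dimension of a 2D convolution."""
    # A dilated kernel covers dilation * (kernel - 1) + 1 input positions.
    effective_kernel = dilation * (kernel - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# Workload C1: 224 x 224 input, 7x7 kernel, stride 2, padding 3, dilation 1.
print(conv2d_output_size(224, 7, 2, 3, 1))  # 112
```

Applied row by row, the same formula gives, e.g., 28 x 28 for C4 and 7 x 7 for C10.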

Dataset:

This dataset is intended for research purposes only, to support and contribute to the graph research community. The configuration space design and the collected execution times may be suboptimal; they should not be taken as reference performance figures for the target device, but rather as representative of the problem at hand.

Click on this link to download the dataset

Please cite our paper if you use this dataset in your research:

  @article{tomczak2019simulating,
    title={{Simulating Execution Time of Tensor Programs using Graph Neural Networks}},
    author={Tomczak, Jakub M and Lepert, Romain and Wiggers, Auke},
    journal={arXiv preprint arXiv:1904.11876},
    year={2019}
  }

Content

The dataset is split into 12 sub-directories, one per unique conv2d workload. For each configuration of a workload, we extract the AST of the corresponding tensor program, represented by:

  • A node_ids vector of shape (#nodes,) where each id corresponds to a node type (e.g., for/if/else statements, variable name, etc.)
  • A node_features matrix of shape (#nodes, #features) where the feature extraction procedure follows the procedure for ‘loop context features’ presented in [3].
  • An edges matrix of shape (#edges, 2) where the first column contains source node indices and the second column contains destination node indices.
  • A curve_features vector of shape (#curve-features,), a fixed-size encoding of the entire AST into the ‘context relation features’ described in [3]. We use it as a baseline for comparison with GraphNN-based feature extraction.
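Before training, it can help to verify that a configuration's arrays agree with the shapes listed above. The helper name and its checks are our own suggestion, and the arrays in the example are synthetic, not taken from the dataset.

```python
import numpy as np


def check_ast_arrays(node_ids, node_features, edges, curve_features):
    """Raise ValueError if one configuration's arrays disagree with the documented shapes."""
    if node_ids.ndim != 1:
        raise ValueError("node_ids must have shape (#nodes,)")
    num_nodes = node_ids.shape[0]
    if node_features.ndim != 2 or node_features.shape[0] != num_nodes:
        raise ValueError("node_features must have shape (#nodes, #features)")
    if edges.ndim != 2 or edges.shape[1] != 2:
        raise ValueError("edges must have shape (#edges, 2)")
    # Every source/destination index must refer to an existing node.
    if edges.size and (edges.min() < 0 or edges.max() >= num_nodes):
        raise ValueError("edge index out of range")
    if curve_features.ndim != 1:
        raise ValueError("curve_features must have shape (#curve-features,)")
    return num_nodes, edges.shape[0]


# Synthetic arrays with made-up sizes.
result = check_ast_arrays(
    node_ids=np.array([0, 1, 2]),
    node_features=np.zeros((3, 4), dtype=np.float32),
    edges=np.array([[0, 1], [0, 2]]),
    curve_features=np.zeros(10, dtype=np.float32),
)
print(result)  # (3, 2)
```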

For example, here is a tensor program (an 8 x 8 matrix multiplication, C = A x B) and its associated AST graph:

produce C {
  for (i, 0, 8) {
    for (j, 0, 8) {
      C[((i*8) + j)] = 0.000000f
      for (k, 0, 8) {
        C[((i*8) + j)] = (C[((i*8) + j)] + (A[((i*8) + k)]*B[(j + (k*8))]))
      }
    }
  }
} 
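As an illustration only, the statement-level skeleton of this program could be encoded in the node_ids/edges format described above. The vocabulary and the coarse choice of nodes (loops and stores only) are assumptions; the dataset's ASTs also contain nodes such as variable names and use the vocabulary stored in metadata.json.

```python
import numpy as np

# Hypothetical node-type vocabulary (the real one is stored in metadata.json).
vocabulary = {"produce": 0, "for": 1, "store": 2}

# Statement-level nodes of the example program, in depth-first order:
#   0: produce C, 1: for i, 2: for j, 3: C = 0, 4: for k, 5: C = C + A*B
nodes = ["produce", "for", "for", "store", "for", "store"]

# Parent -> child edges: column 0 = source, column 1 = destination.
edges = np.array([[0, 1], [1, 2], [2, 3], [2, 4], [4, 5]], dtype=np.int64)
node_ids = np.array([vocabulary[n] for n in nodes], dtype=np.int64)

print(node_ids)     # [0 1 1 2 1 2]
print(edges.shape)  # (5, 2)
```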

Each configuration has measurement values associated with it:

  • A cost float32 scalar: the execution time averaged over 5 runs, where each run takes 4 measurements and averages them to mitigate measurement noise.
  • An error_no int8 scalar indicating a potential compilation or runtime error, where 0 means no error.
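In practice, configurations with a non-zero error_no should usually be filtered out before training or ranking, since their stored cost is not necessarily a meaningful execution time. A minimal NumPy sketch, using synthetic arrays in place of costs.npy and error_nos.npy (the error code 4 below is arbitrary):

```python
import numpy as np


def valid_costs(costs: np.ndarray, error_nos: np.ndarray):
    """Return (indices, costs) of configurations that compiled and ran without error."""
    ok = error_nos == 0
    return np.flatnonzero(ok), costs[ok]


# Synthetic example (not real measurements); index 1 failed.
costs = np.array([0.012, 0.0, 0.008, 0.020], dtype=np.float32)
error_nos = np.array([0, 4, 0, 0], dtype=np.int8)

idx, c = valid_costs(costs, error_nos)
best = idx[np.argmin(c)]  # fastest valid configuration
print(idx, best)  # [0 2 3] 2
```

Without the mask, the placeholder cost of 0.0 at index 1 would be picked as the fastest configuration.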

Following is the directory structure:

/QAST-x86
    /C1  # conv2d for layer C1 workload
    /C2  # conv2d for layer C2 workload
    # ...
    /C12  # conv2d for layer C12 workload
        /metadata.json  # dataset C12 info
        /costs.npy  # ndarray of shape (#configs,)
        /error_nos.npy  # ndarray of shape (#configs,)
        /1  # features for config 1
        /2  # features for config 2
        # ...
        /399  # features for config 399
            /edges.npy  # ndarray of shape (#edges,2)
            /node_ids.npy  # ndarray of shape (#nodes,)
            /node_features.npy  # ndarray of shape (#nodes,#node-features)
            /curve_features.npy  # ndarray of shape (#curve-features,)
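Configuration sub-directories are named by integer index, so sorting them as strings would put 10 before 2. A small helper (our own, hypothetical) that sorts them numerically and skips the workload-level .npy/.json files:

```python
import tempfile
from pathlib import Path


def config_dirs(workload_dir):
    """Return the per-configuration sub-directories of a workload, sorted numerically."""
    workload_dir = Path(workload_dir)
    dirs = [p for p in workload_dir.iterdir() if p.is_dir() and p.name.isdigit()]
    return sorted(dirs, key=lambda p: int(p.name))


# Demo on a throwaway directory that mimics the layout above.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ["1", "2", "10"]:
        (root / name).mkdir()
    (root / "costs.npy").touch()  # workload-level files are skipped
    names = [p.name for p in config_dirs(root)]

print(names)  # ['1', '2', '10']
```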

For each workload, we also provide a metadata.json file describing how the workload was produced:

{
  "cpu": {  # CPU info
    "arch": "X86_64",
    "brand": "Intel(R) Xeon(R) CPU E5-1620 v4 @ 3.50GHz",
    ...
  },
  "task": {  # TVM task info
    "name": "topi.nn.conv2d",
    "args": [...],
    "target": "llvm",
    "template_key": "direct"
  },
  "vocabulary": {...}  # dict of node-type to node-id
}
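Because vocabulary maps node types to ids, inverting it lets you read a node_ids.npy array back as node-type names. A small sketch, assuming the ids are stored as JSON integers; the vocabulary contents below are made up.

```python
import json

# Hypothetical metadata excerpt; the real "vocabulary" maps node types to ids.
metadata = json.loads('{"vocabulary": {"for": 0, "if": 1, "else": 2}}')
vocabulary = metadata["vocabulary"]

# Invert node-type -> node-id so node ids can be decoded into names.
id_to_type = {node_id: node_type for node_type, node_id in vocabulary.items()}

node_ids = [0, 0, 1, 2]  # e.g. the contents of one node_ids.npy
decoded = [id_to_type[i] for i in node_ids]
print(decoded)  # ['for', 'for', 'if', 'else']
```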

Usage

Here is a suggested way to create a data loader using Deep Graph Library (dgl) to work with graphs:

import json
from pathlib import Path

import numpy as np
import torch  # DGL needs a tensor backend; PyTorch is assumed here
import dgl


class Dataset:
    def __init__(self, path):
        self.path = Path(path)

        # Per-workload measurements, one entry per configuration
        self._costs = np.load(self.path / 'costs.npy')
        self._error_nos = np.load(self.path / 'error_nos.npy')

        with (self.path / 'metadata.json').open('r') as f:
            self.attrs = json.load(f)

    def __len__(self):
        """Return number of records in the dataset"""
        return len(self._costs)

    def __getitem__(self, idx):
        """Get record number idx

        Arguments
        ---------
        idx: int
            Index of record

        Return
        ------
        graph: dgl.DGLGraph
        cost: np.float32
        error_no: int

        """
        config_dir = self.path / str(idx)
        node_features = np.load(config_dir / 'node_features.npy')  # ndarray of shape (#nodes, #features)
        node_ids = np.load(config_dir / 'node_ids.npy')  # ndarray of shape (#nodes,)
        edges = np.load(config_dir / 'edges.npy')  # ndarray of shape (#edges, 2)
        srcs = torch.as_tensor(edges[:, 0], dtype=torch.int64)
        dsts = torch.as_tensor(edges[:, 1], dtype=torch.int64)

        # Build the graph (dgl.graph requires DGL >= 0.5); node data must be
        # backend tensors, not NumPy arrays
        graph = dgl.graph((srcs, dsts), num_nodes=len(node_features))
        graph.ndata['features'] = torch.as_tensor(node_features, dtype=torch.float32)
        graph.ndata['emb_idx'] = torch.as_tensor(node_ids, dtype=torch.int64)

        cost = self._costs[idx]
        error_no = self._error_nos[idx]

        return graph, cost, error_no


if __name__ == '__main__':
    # parameters
    path = Path('/path/to/QAST-x86/C1')
    dataset = Dataset(path)

    for i in range(len(dataset)):
        graph, cost, error_no = dataset[i]
    

If you want to work with TVM directly, you can retrieve the original TVM task from which the dataset was generated:

import json

import tvm
from tvm import autotvm
from tvm.autotvm.task.topi_integration import deserialize_args
import topi  # TVM v0.5 ships TOPI as a separate `topi` package


# Register the TVM template under the task name stored in metadata.json
@autotvm.task.register("topi.nn.conv2d")
def topi_nn_conv2d(data, kernel, stride, padding, dilation, layout, dtype):
    # Turn the serialized ('TENSOR', shape, dtype) arguments back into placeholders
    data, kernel = deserialize_args([tuple(data), tuple(kernel)])
    C = topi.nn.conv2d(data, kernel, stride, padding, dilation, layout, dtype)
    s = topi.generic.schedule_conv2d_nchw([C])
    return s, [data, kernel, C]


if __name__ == '__main__':
    # Get the template arguments of workload C1
    with open('/path/to/QAST-x86/C1/metadata.json', mode='r') as f:
        metadata = json.load(f)
    task_info = metadata['task']

    # Re-create the TVM task (named `tsk` to avoid shadowing the `task` module)
    tsk = autotvm.task.create('topi.nn.conv2d',
                              args=task_info['args'],
                              target=task_info['target'],
                              template_key=task_info['template_key'])

    # Get the schedule configuration for record i
    i = 0
    config = tsk.config_space.get(i)
    print(config)

License

This software is licensed by Qualcomm Technologies, Inc. under this Clear BSD License. If you are using the software on behalf of your employer or another legal entity, you agree to these terms on their behalf as well as on your own behalf, and you represent that you have the legal authority to bind such employer or other legal entity to these terms. If you do not have such authority or you or they do not agree to this Clear BSD License, you and such entity may not use this software and must delete all copies of it.

Copyright (c) 2019 Qualcomm Technologies, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met:

  • Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
  • Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
  • Neither the name of Qualcomm Technologies, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

References

[1] Tianqi Chen, Thierry Moreau, Ziheng Jiang, Haichen Shen, Eddie Q. Yan, Leyuan Wang, Yuwei Hu, Luis Ceze, Carlos Guestrin, and Arvind Krishnamurthy. TVM: end-to-end optimization stack for deep learning. In OSDI, pp. 578–594, 2018. https://www.usenix.org/conference/osdi18/presentation/chen

[2] Jakub M. Tomczak, Romain Lepert, and Auke Wiggers. Simulating Execution Time of Tensor Programs using Graph Neural Networks. In ICLR 2019 Workshop on Representation Learning on Graphs and Manifolds. arXiv preprint arXiv:1904.11876, 2019. https://arxiv.org/abs/1904.11876

[3] Tianqi Chen, Lianmin Zheng, Eddie Yan, Ziheng Jiang, Thierry Moreau, Luis Ceze, Carlos Guestrin, and Arvind Krishnamurthy. Learning to optimize tensor programs. In NeurIPS, pp. 3393–3404, 2018. https://papers.nips.cc/paper/7599-learning-to-optimize-tensor-programs