Paper: arXiv
Authors: Adam Santoro, David Raposo, David G.T. Barrett, Mateusz Malinowski, Razvan Pascanu, Peter Battaglia, Timothy Lillicrap
Date of Publication: 5th June 2017
This paper notably made its way into Ilya Sutskever’s top 30 papers to read. That’s where I found this classic.

Overview

The authors introduce Relational Networks (RNs). Just as CNNs are implicitly biased toward spatial structure and translation invariance, RNs have the capacity for relational reasoning baked into their architecture. In its simplest form, the module is composed of one component that computes relations between all input ‘objects’ and another component that further processes the representations of those relations. All pairwise relations between objects are considered, so an RN has to learn the relations itself (reminds me of something…). The authors achieve superhuman performance on one benchmark and get very good results across all the benchmarks they test.

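$$\mathrm{RN}(O) = f_\phi\Big(\sum_{i,j} g_\theta(o_i, o_j)\Big)$$

where the input is a set of objects $O = \{o_1, o_2, \ldots, o_n\}$ and $o_i \in \mathbb{R}^m$ is the $i$-th object. $f_\phi$ and $g_\theta$ are two functions with parameters $\phi$ and $\theta$: $g_\theta$ computes the pairwise relations between all the objects, and $f_\phi$ processes those representations. In the paper, $f_\phi$ and $g_\theta$ are simple MLPs.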
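A minimal sketch of the RN forward pass ($f_\phi$ applied to the sum of $g_\theta$ over all object pairs), written in plain Python so it runs without dependencies. Assumptions not taken from the paper: tiny randomly initialised MLPs stand in for trained $g_\theta$ and $f_\phi$, objects are 4-dimensional, and the helper names `make_mlp`, `mlp_forward` and `relation_network` are hypothetical. What matters is the structure: $g_\theta$ sees one concatenated pair $[o_i, o_j]$ at a time, its outputs are summed over all pairs, and $f_\phi$ processes the sum.

```python
import random

def make_mlp(sizes, seed):
    """Random, untrained weights for a small MLP (illustration only)."""
    rng = random.Random(seed)
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        weights = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
        biases = [0.0] * n_out
        layers.append((weights, biases))
    return layers

def mlp_forward(layers, x):
    """Dense layers with ReLU on every layer except the last."""
    for idx, (weights, biases) in enumerate(layers):
        x = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weights, biases)]
        if idx < len(layers) - 1:
            x = [max(0.0, v) for v in x]
    return x

def relation_network(objects, g_theta, f_phi):
    """RN(O) = f_phi(sum over all ordered pairs (i, j) of g_theta(o_i, o_j))."""
    pooled = None
    for o_i in objects:
        for o_j in objects:  # includes i == j, matching the sum over all i, j
            relation = mlp_forward(g_theta, o_i + o_j)  # pair input: concatenation [o_i, o_j]
            pooled = relation if pooled is None else [a + r for a, r in zip(pooled, relation)]
    return mlp_forward(f_phi, pooled)

m = 4  # object dimension (hypothetical)
g_theta = make_mlp([2 * m, 16, 8], seed=0)  # pair -> 8-dim relation vector
f_phi = make_mlp([8, 16, 3], seed=1)        # pooled relations -> 3 outputs

rng = random.Random(42)
objects = [[rng.uniform(-1.0, 1.0) for _ in range(m)] for _ in range(5)]

out = relation_network(objects, g_theta, f_phi)
print(out)
```

With 5 objects, $g_\theta$ runs 25 times but only one set of its weights exists; the same weights handle every pair.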

Relational Networks

A few notable strengths of RNs include:

RNs learn to infer relations

It’s clear from the functional form of an RN that it considers the relations of all object pairs: the sum runs over every pair (i, j). Thus RNs do not need to be told which relations exist; they learn to infer the existence and implications of object relations.

RNs are data efficient

RNs use a single function, $g_\theta$, to compute each relation. This encourages generalization, because $g_\theta$ is not encouraged to overfit to any particular object-object pair. Consider an MLP instead: it would receive the set of all objects as input and must learn all relations within its weights, handling them all in a single forward pass. This quickly becomes intractable as the number of objects $n$ grows, since there are $n^2$ possible pairs.
In an RN, by contrast, $g_\theta$ explicitly considers each pair of objects separately instead of processing all relations at once. RNs are designed to learn a single, general relation function, which is applied repeatedly to every pair of objects in the set.
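A back-of-the-envelope illustration of the weight-sharing argument, under hypothetical sizes (object dimension m = 8, first hidden layer of 256 units) and counting only first-layer weights and biases. A plain MLP that takes all n objects concatenated needs an input layer that grows with n, while $g_\theta$ always sees one pair of size 2m, so its parameter count stays fixed; what grows instead is the number of times it is applied, n².

```python
def mlp_first_layer_params(n_objects, obj_dim, hidden):
    """Plain MLP: input is all n objects concatenated (size n * m)."""
    return n_objects * obj_dim * hidden + hidden

def rn_g_first_layer_params(obj_dim, hidden):
    """RN's g_theta: input is a single pair [o_i, o_j] (size 2 * m), independent of n."""
    return 2 * obj_dim * hidden + hidden

def g_theta_applications(n_objects):
    """g_theta is shared and applied once per ordered pair (i, j)."""
    return n_objects * n_objects

m, hidden = 8, 256  # hypothetical sizes
for n in (4, 16, 64):
    print(n, mlp_first_layer_params(n, m, hidden), rn_g_first_layer_params(m, hidden), g_theta_applications(n))
# 4 8448 4352 16
# 16 33024 4352 256
# 64 131328 4352 4096
```

The trade-off is compute rather than parameters: the RN's cost is quadratic in the number of objects, but every pair is handled by the same learned function.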

RNs operate on a set of objects

The summation in the definition ensures that RNs are invariant to the order in which they process the objects. This invariance ensures that the RN’s output is representative of the relations that exist in the object set as a whole and is not affected by the order in which the objects appear.
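A toy check of the order-invariance claim, using scalar “objects” and a hypothetical stand-in `g` for $g_\theta$ (any function of a pair would do). Summing `g` over all ordered pairs gives the same value for every permutation of the objects, whereas a readout with position-dependent weights, as an MLP on the concatenated input would have, changes with the order.

```python
import itertools

def g(a, b):
    """Hypothetical stand-in for g_theta on scalar objects."""
    return a * b + 2 * a - b

def rn_sum(objects):
    """The aggregation inside RN(O): sum of g over all ordered pairs (i, j)."""
    return sum(g(a, b) for a in objects for b in objects)

def order_sensitive(objects):
    """MLP-like readout on the concatenated input: each position has its own weight."""
    weights = [1.0, 2.0, 3.0]
    return sum(w * o for w, o in zip(weights, objects))

objs = [1.0, 2.0, 5.0]
rn_values = {rn_sum(list(p)) for p in itertools.permutations(objs)}
concat_values = {order_sensitive(list(p)) for p in itertools.permutations(objs)}
print(rn_values)              # {88.0}: one value for all 6 orderings
print(sorted(concat_values))  # [12.0, 13.0, 15.0, 17.0, 19.0, 20.0]: one value per ordering
```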

Tasks

The authors primarily focus on 4 tasks to test RNs.

CLEVR

CLEVR contains images of 3D-rendered objects, such as spheres and cylinders. Each image is associated with a number of questions that fall into different categories. For example, query attribute questions may ask “What is the color of the sphere?”, while compare attribute questions may ask “Is the cube the same material as the cylinder?”. An important feature of CLEVR is that many questions are explicitly relational in nature.

An example demonstrating a problem of CLEVR




The authors use a CNN to process the image. The CNN reduces the image to k feature maps (one per kernel), each of size d × d. Each of the d² k-dimensional cells is treated as an object, tagged with a coordinate indicating its relative spatial position (look at the image below, it’s way more clear). The questions are tokenized so that each word is a token; the token embeddings are passed through an LSTM, and the final LSTM state is used as the question representation q. The question representation is concatenated with each object pair before it enters $g_\theta$, so the relation function becomes $g_\theta(o_i, o_j, q)$. MLPs are used as $f_\phi$ and $g_\theta$. The image below explains the entire architecture really well.

Model architecture used to achieve superhuman performance on CLEVR


They obtain superhuman performance on this task.

Performance of RNs on CLEVR
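A minimal sketch of how the CLEVR inputs described above could be turned into RN inputs, in plain Python. Assumptions: the CNN output arrives as nested lists of shape [k][d][d], the coordinate tag is a hypothetical choice (row and column rescaled to [-1, 1]), and the question vector `q` is a placeholder for the LSTM’s final state. Each of the d² cells becomes one k-dimensional object plus its coordinates, and every ordered pair is concatenated with q, giving the input that $g_\theta(o_i, o_j, q)$ would receive.

```python
def feature_maps_to_objects(fmaps):
    """fmaps: k feature maps, each d x d, indexed [channel][row][col].
    Returns d*d objects: the k-dim cell vector plus 2 coordinates. Assumes d >= 2."""
    k, d = len(fmaps), len(fmaps[0])
    objects = []
    for row in range(d):
        for col in range(d):
            cell = [fmaps[c][row][col] for c in range(k)]  # one value from each map
            # hypothetical coordinate scheme: row and column rescaled to [-1, 1]
            coords = [2 * row / (d - 1) - 1, 2 * col / (d - 1) - 1]
            objects.append(cell + coords)
    return objects

def question_conditioned_pairs(objects, q):
    """Inputs to g_theta: every ordered pair (o_i, o_j) concatenated with the question vector q."""
    return [o_i + o_j + q for o_i in objects for o_j in objects]

# toy sizes (hypothetical): k = 3 maps of size 4 x 4, 6-dim question embedding
k, d = 3, 4
fmaps = [[[float(100 * c + 10 * r + col) for col in range(d)] for r in range(d)] for c in range(k)]
objs = feature_maps_to_objects(fmaps)
pairs = question_conditioned_pairs(objs, q=[0.5] * 6)
print(len(objs), len(objs[0]))    # 16 5   (16 objects, each 3 + 2 values)
print(len(pairs), len(pairs[0]))  # 256 16 (256 pairs, each 5 + 5 + 6 values)
```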

bAbI

bAbI is a purely text-based QA dataset. There are 20 categories of questions, each corresponding to a particular type of reasoning, such as deduction, induction, or counting. Each question is associated with a set of supporting facts (the support set). For example, the facts “Sandra picked up the football” and “Sandra went to the office” support the question “Where is the football?” (answer: “office”). A model passes a category if it scores above 95% on it.

The authors use positional encoding to tag each sentence with the order in which it appears in the support set. Each sentence is then processed word by word by an LSTM, and the final LSTM state is the representation of that sentence, forming one object to feed into the RN. A single LSTM processes all the support-set sentences. As with CLEVR, a separate LSTM processes the question.
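A rough sketch of turning a bAbI support set into RN objects. To keep the example dependency-free, the sentence LSTM is replaced by a hypothetical bag-of-words encoder, and the position tag is a one-hot vector with a hypothetical cap of `max_sentences`; neither is the paper’s exact encoding. The question would be encoded separately and concatenated onto each pair, as in the CLEVR model.

```python
def encode_sentence(sentence, vocab):
    """Hypothetical stand-in for the sentence LSTM: a bag-of-words count vector."""
    vec = [0.0] * len(vocab)
    for word in sentence.lower().replace(".", "").split():
        if word in vocab:
            vec[vocab[word]] += 1.0
    return vec

def support_set_to_objects(support_set, vocab, max_sentences=20):
    """One object per sentence: sentence encoding + one-hot position tag.
    Keeps only the most recent max_sentences sentences (hypothetical cap)."""
    objects = []
    for pos, sentence in enumerate(support_set[-max_sentences:]):
        tag = [0.0] * max_sentences
        tag[pos] = 1.0  # marks where the sentence appears in the support set
        objects.append(encode_sentence(sentence, vocab) + tag)
    return objects

support = ["Sandra picked up the football.", "Sandra went to the office."]
words = sorted({w for s in support for w in s.lower().replace(".", "").split()})
vocab = {w: i for i, w in enumerate(words)}
objs = support_set_to_objects(support, vocab)
print(len(vocab), len(objs), len(objs[0]))  # 8 2 28
```

Without the position tag, the two sentences would be indistinguishable by order, and the RN’s sum would discard which fact came first.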

The model passes 18 of the 20 tasks.

There are 2 other tasks: ‘Sort-of-CLEVR’, a task the authors built explicitly to test relational reasoning, and a task based on a simulation of moving balls, in which the RN had to make predictions from the current and previous states of the simulation.
The RN performed well on all the tasks explored in the paper.

Previously, it was believed that models performed poorly on these tasks because of a lack of good methods for representing the questions. But RNs do well on these tasks, while all the other models lacking relational reasoning perform poorly, especially on CLEVR. This led the authors to conclude that relational reasoning is the cause of their performance gains.

Thoughts

Relational Networks are closely related to transformers with attention, where the attention map plays the role of the pairwise ‘object relations’. But instead of simply summing the relations up, attention applies a softmax and multiplies by the values $V$. Relational Networks can be seen as a restricted form of attention-based networks.
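A sketch contrasting the two aggregations on the same objects. Simplifying assumptions: the attention uses a single head with identity projections (Q = K = V = the objects), $f_\phi$ is omitted, and `pair_g` is a hypothetical relation function. The RN pools everything into one vector by summing $g$ over all pairs; attention computes the same kind of pairwise scores but normalises them with a softmax per query and uses them to mix the values, giving one output per object.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def rn_aggregate(objects, g):
    """RN (without f_phi): sum g over every ordered pair, one pooled vector for the set."""
    out = [0.0] * len(g(objects[0], objects[0]))
    for o_i in objects:
        for o_j in objects:
            out = [a + b for a, b in zip(out, g(o_i, o_j))]
    return out

def self_attention(objects):
    """Single-head attention with identity projections (simplifying assumption)."""
    d = len(objects[0])
    outputs = []
    for q in objects:
        scores = [dot(q, k) / math.sqrt(d) for k in objects]  # pairwise "relations"
        mx = max(scores)
        weights = [math.exp(s - mx) for s in scores]          # numerically stable softmax
        total = sum(weights)
        weights = [w / total for w in weights]
        outputs.append([sum(w * v[c] for w, v in zip(weights, objects)) for c in range(d)])
    return outputs  # one mixed vector per object, not a single pooled vector

def pair_g(a, b):
    """Hypothetical relation function: the dot product as a 1-dim relation."""
    return [dot(a, b)]

objs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(rn_aggregate(objs, pair_g))  # [8.0]
print(self_attention(objs))        # 3 vectors, each a softmax-weighted mix of the objects
```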
The field has come a very long way since 2017; reading this paper feels like reading a relic of the past.