L'attention n'est pas tout ce dont vous avez besoin

Alors que je travaillais sur MacGraph, un réseau de neurones qui répond aux questions à l’aide de graphiques de connaissances, je suis tombé sur un problème: comment savoir si quelque chose ne figurait pas dans une liste.

Dans cet article, nous allons vous montrer notre solution, le signal de focalisation, et montrer comment elle fonctionne de manière fiable sur une gamme de différents jeux de données et architectures.

Contexte

En programmation traditionnelle, il est facile de savoir si quelque chose ne figure pas dans une liste: exécutez une boucle for sur les éléments, et si vous arrivez à la fin sans la trouver, alors ce n’est pas là. Cependant, les réseaux de neurones ne sont pas si simples.

Les réseaux de neurones sont composés de fonctions différentiables, de sorte qu'ils peuvent être formés en utilisant la descente de gradient. Les opérateurs d'égalité, pour les boucles et si les conditions, les éléments standard de la programmation traditionnelle utilisés pour résoudre cette tâche, ne fonctionnent pas très bien dans les réseaux de neurones:

  • Un opérateur d'égalité est essentiellement une fonction pas à pas, qui a un gradient nul presque partout (et rompt donc la rétro-propagation de la descente de gradient)
  • Si les conditions utilisent généralement un signal booléen pour changer de branche, ce qui est encore une fois le résultat d'une fonction pas à pas problématique
  • Bien que les boucles puissent être inefficaces sur les GPU, et parfois même inutiles, les bibliothèques de réseaux neuronaux exigent que toutes les données aient les mêmes dimensions (par exemple, TensorFlow s'exécute sur un graphe tenseur défini de manière statique).

Une technique populaire de réseau de neurones pour travailler avec des listes d'éléments (par exemple, traduire des phrases les traitant comme des listes de mots) consiste à appliquer «l'attention». Il s'agit d'une fonction dans laquelle une «requête» apprise de ce que recherche le réseau est comparée à chaque élément de la liste, et une somme pondérée des éléments similaires à la requête est générée:

Dans l'exemple ci-dessus,

  1. La requête est produite par points avec chaque élément de la liste pour calculer un «score». Ceci est fait en parallèle pour tous les articles
  2. Les scores sont ensuite passés à travers softmax pour les transformer en une liste dont la somme est 1.0. Les scores peuvent ensuite être utilisés comme distribution de probabilité.
  3. Enfin, une somme pondérée des éléments est calculée, en pondérant chaque élément par son score.

L'attention a été très fructueuse et constitue la base des meilleurs modèles de traduction actuels. Le mécanisme a particulièrement bien fonctionné car:

  • C’est rapide et simple à mettre en œuvre
  • Comparé à un réseau neuronal récurrent (par exemple, LSTM), il est beaucoup plus capable de se référer aux valeurs passées de la séquence d'entrée. Un LSTM doit apprendre à conserver séquentiellement les valeurs passées ensemble dans un seul état interne lors de plusieurs itérations RNN, tandis que l'attention peut rappeler les valeurs de séquence passées à n'importe quel moment du passage.
  • De nombreuses tâches peuvent être résolues en réorganisant et en combinant des éléments de liste pour former une nouvelle liste (par exemple, les modèles d’attention ont été des composants importants dans de nombreux modèles de traduction, de réponse aux questions et de raisonnement de pointe).

Notre problème

En dépit de la polyvalence et du succès de l’attention, elle présente une lacune qui a nui à notre travail de réponse aux questions graphiques: l’attention ne nous indique pas si un élément est présent dans une liste.

C'est ce qui s'est passé pour la première fois lorsque nous avons tenté de répondre à des questions telles que «Existe-t-il une station appelée London Bridge?» Et «La gare de Trafalgar Square est-elle adjacente à la gare de Waterloo?». Nos tables de noeuds et de bords de graphiques ont toutes ces informations à extraire, mais l’attention en elle-même n’a pas permis de déterminer l’existence d’un élément.

Cela se produit car attention renvoie une somme pondérée de la liste. Si la requête correspond (par exemple un score élevé) à un élément de la liste, le résultat sera presque exactement cette valeur. Si la requête ne correspond à aucun élément, une somme de tous les éléments de la liste est renvoyée. Sur la base des résultats de l’attention, le reste du réseau ne peut pas facilement différencier ces deux situations.

Notre solution

La solution simple que nous proposons consiste à générer un agrégat scalaire des scores bruts de requête d'élément (par exemple avant d'utiliser softmax). Ce signal sera faible si aucun élément n'est similaire à la requête et élevé si de nombreux éléments le sont.

En pratique, cela a été très efficace (en fait, la seule solution robuste parmi les nombreuses que nous avons testées) pour résoudre des questions d’existence. A partir de maintenant, nous nous référerons à ce signal comme "focus". Voici une illustration du réseau d’attention que nous avons montré précédemment, avec le signal de focalisation ajouté:

Attention avec signal de focalisation, utilisant une somme pour l'agrégation des scores bruts

Montage expérimental

Constatant que le signal de focalisation était essentiel pour que MacGraph puisse mener à bien certaines tâches, nous avons testé ce concept sur une gamme de jeux de données et d'architectures de modèle.

Le réseau de neurones de l’expérience

Nous avons construit un réseau qui prend une liste d'éléments et un élément souhaité (la «requête»), et indique si cet élément était dans la liste. Le réseau prend les entrées, effectue une surveillance (éventuellement avec notre signal de focalisation), transforme les sorties en plusieurs couches résiduelles, puis génère une distribution binaire¹ indiquant si l'élément a été trouvé.

La perte de réseau est calculée à l'aide d'une entropie croisée softmax et est formée à l'aide de l'optimiseur Adam. Le taux d’apprentissage idéal de chaque variante de réseau est déterminé avant la formation à l’aide du détecteur de taux d’apprentissage.

Dans nos expériences, nous faisons varier le réseau en:

  • Inclusion / suppression du signal de mise au point de l'étape d'attention («Utiliser la mise au point» ci-dessous)
  • Varier la fonction de score d'attention
  • Faire varier la fonction d'agrégation de signal de mise au point
  • Faire varier la fonction d’activation de la première couche de transformation résiduelle («Activation de la sortie» ci-dessous)

Nous appliquons ensuite ces variations de réseau à quelques jeux de données différents, listés dans la section suivante.

Voici toutes les configurations de réseau que nous avons testées:

Le signal de focalisation s'est avéré essentiel pour que le réseau atteigne avec fiabilité une précision supérieure à 99% sur l'ensemble des jeux de données et des configurations de réseau.

Le code de nos expériences est open-source dans notre GitHub.

Ensembles de données

Chaque jeu de données est un ensemble d'exemples, chaque exemple contient les entités en entrée Liste d'éléments, Élément souhaité et la sortie de vérité au sol, L'élément est dans la liste. Un élément est un vecteur à N dimensions de nombres à virgule flottante.

Chaque jeu de données a été construit de sorte qu'une précision de 100% soit possible.

Nous avons testé sur trois jeux de données différents, chacun avec une source d'éléments différente:

  • Vecteurs one-hot orthogonaux de longueur 12
  • Plusieurs vecteurs chauds (par exemple des chaînes aléatoires de 1,0 s et 0,0 s) de longueur 12
  • Word2vec vecteurs de longueur 300

Chaque ensemble de données a des classes de réponses équilibrées (c'est-à-dire un nombre égal de réponses True et False)

Les vecteurs one-hot et many-hot ont été générés aléatoirement. Chaque exemple de formation comporte une liste de deux éléments. Chaque jeu de données contient 5 000 éléments de formation. Voici l'un des exemples de formation:

Un exemple de formation dans le jeu de données orthogonal one-hot. (Vecteur de longueur 3 pour la simplicité visuelle)

Pour notre ensemble de données word2vec, nous avons utilisé des intégrations Word2vec pré-calculées formées sur Google Actualités, qui peuvent être téléchargées ici.

Le jeu de données word2vec contient des listes d'éléments d'une longueur maximale de 77, chacune représentant une phrase mélangée, choisie au hasard dans un article wikipedia de 300 lignes. Chaque élément souhaité est un vecteur word2vec choisi au hasard parmi les mots contenus dans l'article.

Résultats

Le signal de focalisation s'est avéré essentiel pour déterminer de manière fiable l'existence d'un élément:

  • Sans signal de mise au point, aucune configuration réseau obtenue avec une précision supérieure à 56% pour tous les jeux de données.
  • Pour les vecteurs de jeu de données orthogonaux, aucun réseau n'obtient une précision> 80% sans le signal de focalisation
  • Les réseaux avec le signal de focalisation ont atteint une précision> 98% dans tous les jeux de données et toutes les configurations, les réseaux sans signal de focalisation n'ont atteint que> 98% dans 4 configurations sur 9

Il convient de noter que le réseau sans signal de focalisation a atteint une précision supérieure à 98% dans certaines configurations et ensembles de données.

Conclusion

Dans cet article, nous avons présenté un nouveau concept de réseau d’attention, le «signal de mise au point». Nous avons montré qu’il s’agissait d’un mécanisme robuste permettant de détecter l’existence d’éléments dans une liste, ce qui est une opération importante pour le raisonnement de la machine et la réponse aux questions. Le signal de focus a détecté avec succès l'existence de l'élément dans tous les jeux de données et les configurations réseau testés.

Nous espérons que ce travail aidera d'autres équipes à relever le défi de l'existence d'un élément et ajoutera plus de clarté à la détermination de l'architecture à utiliser.

Les recherches d’Octavian

La mission d’Octavian est de développer des systèmes avec des capacités de raisonnement au niveau humain. Nous croyons que les données graphiques et l'apprentissage en profondeur sont des ingrédients clés pour rendre cela possible. Si vous souhaitez en savoir plus sur nos recherches ou contribuer, consultez notre site Web ou contactez-nous.

Remerciements

Je voudrais remercier David Mack pour l’idée initiale et pour l’aide apportée à la majeure partie de la rédaction de cet article. J'aimerais également remercier Andrew Jefferson d'avoir écrit une grande partie du code et de m'avoir guidé tout au long de cet article. Enfin, j'aimerais remercier Octavian de m'avoir donné l'occasion de travailler sur ce projet.

Notes de bas de page

  1. Les couches de transformation de sortie ont été copiées à partir de MacGraph et pourraient éventuellement être simplifiées. La couche sigmoïde (σ) n’est absolument pas nécessaire dans ce modèle car elle limite la plage des entrées à softmax et empêche donc le modèle de s’entraîner à perte zéro (bien qu’elle puisse toujours atteindre une précision de 100% puisque le résultat est argmax ' en un vecteur à chaud et comparé à l’étiquette). Nous ne pensons pas que cette complexité invalide l’un des résultats, mais nous nous sommes sentis obligés de l’annoncer.