Masked Contrastive Representation Learning for Reinforcement Learning

Times cited: 29
Authors
Zhu, Jinhua [1]
Xia, Yingce [2]
Wu, Lijun [2]
Deng, Jiajun [1, 3]
Zhou, Wengang [1, 4]
Qin, Tao [2]
Liu, Tie-Yan [2]
Li, Houqiang [1, 4]
Affiliations
[1] Univ Sci & Technol China, CAS Key Lab GIPAS, Hefei 230027, Peoples R China
[2] Microsoft Res, Beijing 100080, Peoples R China
[3] Univ Sci & Technol China, CAS Key Lab GIPAS, Hefei, Peoples R China
[4] Hefei Comprehens Natl Sci Ctr, Inst Artificial Intelligence, Hefei, Peoples R China
Funding
National Natural Science Foundation of China
Keywords
Transformers; Training; Task analysis; Representation learning; Convolutional neural networks; Image reconstruction; Games; Contrastive learning; reinforcement learning; transformer
DOI
10.1109/TPAMI.2022.3176413
CLC number
TP18 [Artificial intelligence theory]
Subject classification codes
081104; 0812; 0835; 1405
Abstract
In pixel-based reinforcement learning (RL), the states are raw video frames, which are mapped into hidden representations before being fed to a policy network. Recently, the most prominent work on improving the sample efficiency of state representation learning has been based on contrastive unsupervised representation. Observing that consecutive video frames in a game are highly correlated, we propose a new algorithm to further improve data efficiency: masked contrastive representation learning for RL (M-CURL), which takes the correlation among consecutive inputs into consideration. In our architecture, besides a CNN encoder that produces the hidden representation of the input state and a policy network that selects actions, we introduce an auxiliary Transformer encoder module to leverage the correlations among video frames. During training, we randomly mask the features of several frames and use the CNN encoder and the Transformer to reconstruct them from the context frames. The CNN encoder and the Transformer are jointly trained via contrastive learning, so that each reconstructed feature is similar to its ground-truth feature and dissimilar to the others. During policy evaluation, only the CNN encoder and the policy network are used to take actions; the Transformer module is discarded. Our method achieves consistent improvements over CURL on 14 out of 16 environments from the DMControl suite and 23 out of 26 environments from the Atari 2600 games. The code is available at https://github.com/teslacool/m-curl.
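The training signal described in the abstract can be sketched in pure Python. The idea is to mask some frame features, reconstruct them from the context frames, and score the reconstructions contrastively. This is a hedged illustration under stated assumptions, not the released M-CURL code. The helper names `random_mask`, `reconstruct`, `info_nce_loss`, and `masked_contrastive_loss` are hypothetical. The CNN encoder plus Transformer is replaced by simple nearest-neighbour temporal interpolation. The negatives are assumed to be the ground-truth features of the other masked frames. A plain dot product stands in for the learned bilinear similarity that CURL uses.

```python
import math
import random


def dot(a, b):
    """Dot-product similarity (assumption: CURL itself uses a learned bilinear q^T W k)."""
    return sum(x * y for x, y in zip(a, b))


def random_mask(num_frames, mask_ratio, rng):
    """Hypothetical helper: choose which frame features to mask.

    At least one frame is masked and at least one is kept as context.
    """
    num_masked = int(round(num_frames * mask_ratio))
    num_masked = min(num_frames - 1, max(1, num_masked))
    return sorted(rng.sample(range(num_frames), num_masked))


def reconstruct(features, masked):
    """Stand-in for the CNN encoder + Transformer (NOT the paper's module):
    predict each masked feature by averaging the nearest unmasked frame on each side."""
    masked_set = set(masked)
    predictions = {}
    for i in masked:
        neighbours = []
        left = i - 1
        while left >= 0 and left in masked_set:
            left -= 1
        if left >= 0:
            neighbours.append(features[left])
        right = i + 1
        while right < len(features) and right in masked_set:
            right += 1
        if right < len(features):
            neighbours.append(features[right])
        if not neighbours:
            raise ValueError("at least one frame must remain unmasked")
        # Average the neighbouring features dimension by dimension.
        predictions[i] = [sum(col) / len(neighbours) for col in zip(*neighbours)]
    return predictions


def info_nce_loss(predictions, targets, temperature=1.0):
    """InfoNCE: each prediction should score its own ground-truth feature higher
    than the ground-truth features of the other masked frames (assumed negatives)."""
    keys = sorted(targets)
    total = 0.0
    for pos, i in enumerate(keys):
        logits = [dot(predictions[i], targets[j]) / temperature for j in keys]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[pos]  # -log softmax at the positive pair
    return total / len(keys)


def masked_contrastive_loss(features, mask_ratio=0.25, seed=0):
    """Hypothetical end-to-end signal: mask, reconstruct, score contrastively."""
    rng = random.Random(seed)
    masked = random_mask(len(features), mask_ratio, rng)
    targets = {i: features[i] for i in masked}  # ground-truth features of masked frames
    predictions = reconstruct(features, masked)
    return info_nce_loss(predictions, targets)


if __name__ == "__main__":
    # Smoothly varying 2-D "frame features", mimicking correlated consecutive frames.
    frames = [[math.cos(t / 3), math.sin(t / 3)] for t in range(8)]
    print(masked_contrastive_loss(frames))
```

In an actual agent, the features would come from the CNN encoder and the predictor would be a trainable Transformer, with the loss back-propagated into both. The interpolation here only illustrates why correlated neighbouring frames make masked features recoverable from context.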
Pages: 3421-3433
Number of pages: 13