We propose distilling target visual information from a set of target encoders into the intermediate representations of the LLM. During training, we adopt a predictive embedding optimization approach at selected LLM layers, minimizing the embedding losses alongside the next-token prediction (NTP) objective, which yields a vision-centric approach to training the Multimodal Large Language Model (MLLM). At inference, we use only a single base vision encoder, resulting in compute-efficient visual encoding along with better visual perception.
Abstract
The standard practice for developing contemporary MLLMs is to feed features from vision encoder(s) into the LLM and train with natural language supervision. In this work, we posit an overlooked opportunity to optimize the intermediate LLM representations through a vision perspective (objective), i.e., solely natural language supervision is sub-optimal for the MLLM's visual understanding ability. To that end, we propose OLA-VLM, the first approach distilling knowledge into the LLM's hidden representations from a set of target visual representations. Firstly, we formulate the objective during the pretraining stage in MLLMs as a coupled optimization of predictive visual embedding and next text-token prediction. Secondly, we investigate MLLMs trained solely with natural language supervision and identify a positive correlation between the quality of visual representations within these models and their downstream performance. Moreover, upon probing our OLA-VLM, we observe improved representation quality owing to the embedding optimization. Thirdly, we demonstrate that our OLA-VLM outperforms the single and multi-encoder baselines, proving our approach's superiority over explicitly feeding the corresponding features to the LLM. Particularly, OLA-VLM boosts performance by an average margin of up to 2.5% on various benchmarks, with a notable improvement of 8.7% on the Depth task in CV-Bench.
Visually Probing MLLM Representations
Representation Quality's Effect on Performance
We establish a relationship between the visual representation quality inside the LLM and the observed downstream performance through a series of probing experiments against features from a set of target encoders (see the sketch after this list):
(a) We observe that increasing the amount of training data, even when training solely with the next-token prediction objective, enhances the visual representation quality within the LLM and improves downstream performance, underscoring the effectiveness of our probing setup.
(b) OLA-VLM exhibits superior visual representations and performance compared to LLaVA-1.5 under the same settings, demonstrating the effectiveness of minimizing the predictive embedding objective during training.
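To make the probing setup concrete, here is a minimal sketch, assuming frozen per-layer hidden states from the MLLM and pre-computed features from a target encoder (depth, segmentation, or generation). The Probe module, the mean-pooling choice, and the cosine-based quality measure are illustrative assumptions rather than the released implementation.

# Minimal probing sketch (assumptions: PyTorch, frozen MLLM hidden states, and
# pre-computed target-encoder features; `Probe` and its design are illustrative).
import torch
import torch.nn as nn
import torch.nn.functional as F

class Probe(nn.Module):
    """Lightweight probe mapping LLM hidden states to a target-encoder feature space."""

    def __init__(self, llm_dim: int, target_dim: int):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(llm_dim, target_dim),
            nn.GELU(),
            nn.Linear(target_dim, target_dim),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: (batch, seq_len, llm_dim) from one LLM layer.
        pooled = hidden_states.mean(dim=1)   # simple mean pooling over tokens
        return self.proj(pooled)             # (batch, target_dim)

def probe_quality(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean cosine similarity between probe predictions and target-encoder features."""
    return F.cosine_similarity(pred, target, dim=-1).mean()

# Training sketch: the MLLM stays frozen; only the probe at layer `l` is updated,
# e.g. by minimizing `1 - probe_quality(probe(hidden_l), target_feats)`.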
Key Insights
Layerwise Trend: The middle-layer (12-24) probes show the best representation quality for the depth and segmentation probing tasks, with quality trending upward through the initial layers and downward in the deeper layers; see the per-layer evaluation sketch after these insights. We attribute the fairly high cosine similarity (greater than 0.7) across all layers, with some irregularities, during generation probing to the choice of a CLIP-based target encoder.
Visual Encoding Approach: The probes for the multi-encoder MLLM learn better representations than the probes for the single-encoder MLLM. We also train probes on our OLA-VLM and observe that they fall between the two baselines, serving as a good trade-off between efficiency and visual representation accuracy.
Training Data: We observe that as the amount of training data for the probed model increases, the probes show gradual improvement, indicating that the representations of the visual world inside MLLMs, and consequently performance on downstream tasks, improve with just natural language supervision on more data!
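As a concrete companion to the layer-wise trend, a per-layer evaluation loop might look like the sketch below; the probes and hidden_by_layer containers, tensor shapes, and layer indexing are assumptions for illustration.

# Hypothetical per-layer probe evaluation; containers, shapes, and layer indices are
# assumptions, and the `Probe`-style modules are taken from the earlier sketch.
import torch
import torch.nn.functional as F

@torch.no_grad()
def layerwise_quality(hidden_by_layer: dict, target_feats: torch.Tensor, probes: dict) -> dict:
    """Map each probed layer index to the mean cosine similarity with target features."""
    scores = {}
    for layer_idx, probe in probes.items():
        pred = probe(hidden_by_layer[layer_idx])   # (batch, target_dim)
        scores[layer_idx] = F.cosine_similarity(pred, target_feats, dim=-1).mean().item()
    return scores

# Plotting `scores` against the layer index would surface the rise-then-fall trend
# reported above for the depth and segmentation probes.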
Embedding Visual Information into LLM
During Pre-Training (PT), we optimize an embedding loss at specific layers for each target encoder: layers \(d \in \mathbb{D}\), \(s \in \mathbb{S}\), and \(g \in \mathbb{G}\) for the depth, segmentation, and generation tasks, respectively. We use a resampler-based embedding predictor, denoted as \(\mathbf{P}^{l}_{\text{task}}\), at each layer \(l\) to output predictions. Each predictor takes two inputs: a set of learnable queries (\(\mathbf{Q}^{\text{task}}\)) and the token sequence from layer \(l\), with the special tokens for the other tasks omitted. The final loss is the sum of the embedding losses across all selected layers and the next-token prediction objective, as sketched below.
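The following is a hedged sketch of this coupled objective, assuming PyTorch-style modules; the predictor architecture (EmbedPredictor), the exact form of the embedding loss, and names such as hidden, targets, and predictors are expository assumptions rather than the paper's released code.

# Sketch of the coupled pre-training objective: NTP loss plus embedding losses from
# resampler-style predictors at selected layers (architecture and loss are assumptions).
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbedPredictor(nn.Module):
    """Resampler-style predictor: learnable queries cross-attend to one layer's tokens."""

    def __init__(self, llm_dim: int, target_dim: int, num_queries: int = 32):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(num_queries, llm_dim))
        self.cross_attn = nn.MultiheadAttention(llm_dim, num_heads=8, batch_first=True)
        self.out_proj = nn.Linear(llm_dim, target_dim)

    def forward(self, layer_tokens: torch.Tensor) -> torch.Tensor:
        # layer_tokens: (batch, seq_len, llm_dim) from a chosen LLM layer.
        q = self.queries.unsqueeze(0).expand(layer_tokens.size(0), -1, -1)
        attended, _ = self.cross_attn(q, layer_tokens, layer_tokens)
        return self.out_proj(attended)            # (batch, num_queries, target_dim)

def embedding_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """One plausible embedding loss: smooth-L1 plus a (1 - cosine) term."""
    cos_term = 1.0 - F.cosine_similarity(pred, target, dim=-1).mean()
    return F.smooth_l1_loss(pred, target) + cos_term

# Coupled loss over the selected layer sets D (depth), S (seg), and G (gen):
# total_loss = ntp_loss
# for task, layers in {"depth": D, "seg": S, "gen": G}.items():
#     for l in layers:
#         total_loss = total_loss + embedding_loss(predictors[task][l](hidden[l]), targets[task])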
During IFT, we train with only the next-token prediction objective while keeping the special tokens frozen so as not to affect their task-specific nature.
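One way to keep the task-specific special tokens fixed during IFT is to zero the gradients of their rows in the input embedding table; the sketch below assumes a PyTorch model, and the token ids in the usage comment are hypothetical.

# Sketch: freeze selected special-token rows of the embedding table during IFT
# (the gradient-hook approach and the token ids shown are assumptions).
import torch

def freeze_special_tokens(embed: torch.nn.Embedding, special_token_ids: list) -> None:
    """Zero the gradient rows of the given token ids so they are never updated."""
    ids = torch.tensor(special_token_ids, dtype=torch.long)

    def hook(grad: torch.Tensor) -> torch.Tensor:
        grad = grad.clone()
        grad[ids] = 0.0        # no updates for the task-specific special tokens
        return grad

    embed.weight.register_hook(hook)

# Usage (hypothetical token ids): call once before IFT, then optimize the
# next-token prediction objective as usual.
# freeze_special_tokens(model.get_input_embeddings(), special_token_ids=[32001, 32002])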
Results
We present results across different base encoders and decoder LLMs. Our OLA-VLM outperforms the single-encoder and multi-encoder LLaVA-1.5 baselines by up to 2.5% and 0.9% on average across various benchmarks, respectively. The best numbers are set in bold for every base-encoder and decoder-LLM combination. Please check our paper 📄 for extensive experiments and ablation studies.
Citation
If you found our work useful, please consider starring ⭐ us on GitHub and citing 📚 us in your research!
@article{jain2024ola_vlm,
title={{OLA-VLM: Elevating Visual Perception in Multimodal LLMs with Auxiliary Embedding Distillation}},
author={Jitesh Jain and Zhengyuan Yang and Humphrey Shi and Jianfeng Gao and Jianwei Yang},
journal={arXiv},
year={2024}
}