OLA-VLM: Elevating Visual Perception in Multimodal LLMs with Auxiliary Embedding Distillation

Jitesh Jain1,2*         Zhengyuan Yang2         Humphrey Shi1†         Jianfeng Gao2†         Jianwei Yang2†

1 SHI Labs @ Georgia Tech     2 Microsoft Research, Redmond    
*Work done during JJ's internship at Microsoft Research, Redmond.     †Equal Advising

Paper · Code · Demo · Citation

TLDR

We propose distilling target visual information from a set of target encoders into the intermediate representations of the LLM. During training, we adopt a predictive embedding optimization approach at selected LLM layers, minimizing the embedding losses alongside the next-token prediction (NTP) objective, which yields a vision-centric approach to training Multimodal Large Language Models. During inference, we use only a single base vision encoder, resulting in compute-efficient visual encoding along with better visual perception.
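For illustration, a minimal PyTorch-style sketch of the coupled objective is shown below; the tensor names, the loss weight, and the cosine-distance embedding term are assumptions for exposition, and the exact embedding losses follow the paper and code.

import torch.nn.functional as F

def coupled_loss(lm_logits, labels, predicted_embeds, target_embeds, lambda_emb=0.5):
    """Next-token prediction loss plus embedding losses at selected LLM layers.

    predicted_embeds / target_embeds: dicts mapping a layer index to the
    embedding-predictor outputs and the (frozen) target-encoder features.
    Shapes, names, and the cosine-distance term are illustrative assumptions.
    """
    # standard next-token prediction: shift logits and labels by one position
    ntp = F.cross_entropy(
        lm_logits[:, :-1].reshape(-1, lm_logits.size(-1)),
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
    # embedding loss: pull each layer's predicted embeddings toward the target features
    emb = sum(
        1.0 - F.cosine_similarity(predicted_embeds[l], target_embeds[l], dim=-1).mean()
        for l in predicted_embeds
    )
    return ntp + lambda_emb * emb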

Abstract

The standard practice for developing contemporary MLLMs is to feed features from vision encoder(s) into the LLM and train with natural language supervision. In this work, we posit an overlooked opportunity to optimize the intermediate LLM representations through a vision perspective (objective), i.e., natural language supervision alone is sub-optimal for the MLLM's visual understanding ability. To that end, we propose OLA-VLM, the first approach to distill knowledge into the LLM's hidden representations from a set of target visual representations. Firstly, we formulate the objective during the pretraining stage in MLLMs as a coupled optimization of predictive visual embedding and next text-token prediction. Secondly, we investigate MLLMs trained solely with natural language supervision and identify a positive correlation between the quality of visual representations within these models and their downstream performance. Moreover, upon probing our OLA-VLM, we observe improved representation quality owing to the embedding optimization. Thirdly, we demonstrate that OLA-VLM outperforms the single- and multi-encoder baselines, proving our approach's superiority over explicitly feeding the corresponding features to the LLM. Particularly, OLA-VLM boosts performance by an average margin of up to 2.5% on various benchmarks, with a notable improvement of 8.7% on the Depth task in CV-Bench.


Visually Probing MLLM Representations

Representation Quality's Effect on Performance

We establish a relationship between the visual representation quality inside the LLM and the observed downstream performance through a series of probing experiments against features from a set of target encoders:

(a) We observe that increasing the amount of training data, even with solely the next-token prediction objective, enhances the visual representation quality within the LLM and improves downstream performance, underscoring the effectiveness of our probing setup.

(b) OLA-VLM exhibits superior visual representations and performance compared to LLaVA-1.5 under the same settings, demonstrating the effectiveness of minimizing the predictive embedding objective during training.
[Figure: Introductory plots]
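To make the setup concrete, here is a minimal sketch of training a probe at a single LLM layer; the MLP probe, mean pooling, and pooled target features are simplifying assumptions (the probes in the paper follow the embedding-predictor design), but the recipe matches the setup above: frozen MLLM hidden states in, frozen target-encoder features as the regression target, cosine similarity as the quality metric.

import torch
import torch.nn as nn
import torch.nn.functional as F

def train_layer_probe(layer_feats, target_feats, epochs=5, lr=1e-4):
    """Fit a small probe from frozen layer-l hidden states to target-encoder features
    and report the mean cosine similarity (higher = better representation quality).

    layer_feats:  (N, T, d_llm) hidden states from one LLM layer (MLLM kept frozen)
    target_feats: (N, d_target) pooled features from a target encoder (depth/seg/gen)
    """
    d_llm, d_target = layer_feats.size(-1), target_feats.size(-1)
    probe = nn.Sequential(nn.Linear(d_llm, d_llm), nn.GELU(), nn.Linear(d_llm, d_target))
    opt = torch.optim.AdamW(probe.parameters(), lr=lr)

    for _ in range(epochs):
        pred = probe(layer_feats.mean(dim=1))   # mean-pool the token sequence, then project
        loss = 1.0 - F.cosine_similarity(pred, target_feats, dim=-1).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()

    with torch.no_grad():
        sims = F.cosine_similarity(probe(layer_feats.mean(dim=1)), target_feats, dim=-1)
    return sims.mean().item()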

Key Insights

[Figure: Probe plots]

Layerwise Trend: The probes at the middle layers (12-24) show the best representation quality for the depth and seg probing tasks, with an upward trend in quality over the initial layers and a downward trend in the deeper layers. For gen probing, we attribute the fairly high cosine similarity (greater than 0.7) across all layers, with some irregularities, to the choice of a CLIP-based target encoder.

Visual Encoding Approach: The probes for the multi-encoder MLLM learn better representations than those for the single-encoder MLLM. We also train probes on our OLA-VLM and observe that their quality falls between the two baselines, making OLA-VLM a good trade-off between efficiency and visual representation accuracy.

Training Data: As the amount of training data for the probed model increases, the probes show gradual improvement, indicating that the representations of the visual world inside MLLMs, and consequently downstream performance, improve with just natural language supervision on more data!


Embedding Visual Information into LLM

[Figure: OLA-VLM architecture]

During Pre-Training (PT), we optimize an embedding loss at specific layers for each target encoder: layers \(d \in \mathbb{D}\), \(s \in \mathbb{S}\), and \(g \in \mathbb{G}\) for the depth, segmentation, and generation tasks, respectively. At each selected layer \(l\), a resampler-based embedding predictor, denoted \(\mathbf{P}^{l}_{\text{task}}\), outputs the predicted embeddings. Each predictor takes two inputs: a set of learnable queries (\(\mathbf{Q}^{\text{task}}\)) and the token sequence from layer \(l\), with the special tokens for the other tasks omitted. The final loss is the sum of the embedding losses across all selected layers and the next-token prediction objective. During Instruction Fine-Tuning (IFT), we train with only the next-token prediction objective while keeping the special tokens frozen so as not to affect their task-specific nature.
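A rough sketch of one such resampler-style predictor is given below; the query count, attention depth, and hidden sizes are placeholders, and the actual implementation in our code may differ in details such as feed-forward blocks and normalization.

import torch
import torch.nn as nn

class EmbeddingPredictor(nn.Module):
    """Sketch of a resampler-style predictor P^l_task (sizes are placeholders).

    Learnable queries Q^task cross-attend into the token sequence from layer l
    and are projected to the target encoder's feature dimension.
    """
    def __init__(self, d_llm, d_target, n_queries=64, n_heads=8, depth=2):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(n_queries, d_llm) * 0.02)
        self.attn_blocks = nn.ModuleList(
            [nn.MultiheadAttention(d_llm, n_heads, batch_first=True) for _ in range(depth)]
        )
        self.norms = nn.ModuleList([nn.LayerNorm(d_llm) for _ in range(depth)])
        self.out = nn.Linear(d_llm, d_target)

    def forward(self, layer_tokens):                      # (B, T, d_llm) from LLM layer l
        q = self.queries.unsqueeze(0).expand(layer_tokens.size(0), -1, -1)
        for attn, norm in zip(self.attn_blocks, self.norms):
            attended, _ = attn(q, layer_tokens, layer_tokens)   # cross-attention into layer tokens
            q = norm(q + attended)                              # residual update of the queries
        return self.out(q)                                # (B, n_queries, d_target)

During PT, the output of each selected layer's predictor would be compared against the corresponding frozen target-encoder features, and the resulting embedding losses are summed with the next-token prediction objective (see the loss sketch in the TLDR above).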


Results

[Figure: Results across base encoders and decoder LLMs]

We present results across different base encoders and decoder LLMs. Our OLA-VLM outperforms the single-encoder and multi-encoder LLaVA-1.5 baselines by up to 2.5% and 0.9% on average across various benchmarks, respectively. The best numbers are set in bold for every base-encoder and decoder-LLM combination. Please check our paper 📄 for extensive experiments and ablation studies.


Citation

If you found our work useful, please consider starring us on GitHub and citing 📚 us in your research!

@article{jain2024ola_vlm,
    title={{OLA-VLM: Elevating Visual Perception in Multimodal LLMs with Auxiliary Embedding Distillation}},
    author={Jitesh Jain and Zhengyuan Yang and Humphrey Shi and Jianfeng Gao and Jianwei Yang},
    journal={arXiv},
    year={2024}
}