The Platonic Representation Hypothesis
Introduction
This year, one of the most impactful papers delves into the fascinating idea that AI models are converging toward a shared understanding of reality, much like Plato’s concept of an ideal world. What makes this paper stand out is how it moves away from the typical structure of algorithm proposals, mathematical proofs, and experiments, and instead takes a more philosophical approach. It explores the deeper implications of AI’s growing alignment across models and modalities, suggesting that as these systems evolve, they might be inching closer to a universal representation of the world. For me, what’s truly inspiring about this paper is how it opens up fresh discussions about the future of AGI (Artificial General Intelligence) by focusing on the bigger picture of how machines are starting to grasp reality in a way that's surprisingly aligned with human cognition. It’s thought-provoking and shines a light on the long-term trajectory of AI development.
Platonic Hypothesis
Plato's allegory of the cave describes individuals confined in a cave, facing a wall and seeing only the shadows cast by objects behind them, which they take to be reality. These shadows symbolize the limited reality perceived through our senses, in contrast with the true forms of objects. The paper's hypothesis is analogous: neural networks trained on different data and for different tasks are converging towards a shared, universal model of reality in how they process and represent information.
This figure visually explains the concept of different data modalities (images and text) as projections from a common underlying reality (Z). The hypothesis posits that these converging structures reflect a deeper, universal model of reality, regardless of the specific tasks or data they handle.
Representation Alignment
Specifically, alignment refers to the growing similarity in how data points are represented across different models. A representation is a function $f : \mathcal{X} \rightarrow \mathbb{R}^n$ that assigns a feature vector to each input in some data domain $\mathcal{X}$. Representational alignment is a measure of the similarity of the similarity structures induced by two representations, that is, a similarity metric over kernels. A kernel, $K : \mathcal{X} \times \mathcal{X} \rightarrow \mathbb{R}$, characterizes how a representation measures similarity between datapoints, i.e., $K(x_i, x_j) = \langle f(x_i), f(x_j)\rangle$, where $\langle \cdot, \cdot\rangle$ denotes an inner product. A kernel-alignment metric, $m : \mathcal{K}\times \mathcal{K} \rightarrow \mathbb{R}$, measures the similarity between two kernels.
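To make this concrete, here is a minimal sketch (in NumPy, not from the paper) of one possible kernel-alignment metric, centered kernel alignment (CKA), applied to inner-product kernels computed from two hypothetical models' features; the paper itself uses a mutual nearest-neighbor metric instead.

```python
import numpy as np

def kernel(feats):
    """Inner-product kernel: K[i, j] = <f(x_i), f(x_j)> for a batch of features."""
    return feats @ feats.T

def center(K):
    """Double-center a kernel matrix."""
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    return H @ K @ H

def cka(K1, K2):
    """Centered kernel alignment: one choice of kernel-alignment metric m(K1, K2)."""
    K1c, K2c = center(K1), center(K2)
    hsic = (K1c * K2c).sum()
    return hsic / np.sqrt((K1c * K1c).sum() * (K2c * K2c).sum())

# Toy usage: features of the same 100 inputs from two hypothetical models f and g.
feats_f = np.random.randn(100, 512)   # f : X -> R^512
feats_g = np.random.randn(100, 768)   # g : X -> R^768
print(cka(kernel(feats_f), kernel(feats_g)))
```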
The paper explored different ways in which representations are converging.
Different Neural Networks are Converging
Past work has measured representational similarity through a technique called $\textbf{model stitching}$, where two models are combined via an affine stitching layer $h$. Given two models $f$ and $g$, each composed of multiple layers, $f = f_1 \circ ... \circ f_n$ and $g = g_1 \circ ... \circ g_m$, the stitched model $F = f_1 \circ ... \circ f_k \circ h \circ g_{k+1} \circ ... \circ g_m$ feeds the intermediate representation of $f$ at layer $k$ into the rest of $g$ through the affine map $h$. If $F$ still performs well, it indicates that $f$ and $g$ have compatible representations at layer $k$, up to the transform $h$. Lenc & Vedaldi showed that a vision model trained on ImageNet can be stitched to a model trained on Places-365 while maintaining good performance, reflecting the convergence of different neural networks.
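As a rough sketch (not the original studies' code), model stitching can be written in PyTorch roughly as follows; the layer lists, dimensions, and the choice of a linear layer for $h$ are illustrative assumptions.

```python
import torch
import torch.nn as nn

class StitchedModel(nn.Module):
    """Run the first k layers of f, map into g's feature space with an affine
    stitching layer h, then finish with the remaining layers of g."""
    def __init__(self, f_layers, g_layers, k, f_dim, g_dim):
        super().__init__()
        self.f_head = nn.Sequential(*f_layers[:k])   # f_1 ∘ ... ∘ f_k
        self.h = nn.Linear(f_dim, g_dim)             # affine stitching layer h
        self.g_tail = nn.Sequential(*g_layers[k:])   # g_{k+1} ∘ ... ∘ g_m
    def forward(self, x):
        return self.g_tail(self.h(self.f_head(x)))

# Toy usage with made-up layer stacks; in practice f and g are pretrained and
# frozen, and only h is trained on the downstream task.
f_layers = [nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU()]
g_layers = [nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 10)]
F = StitchedModel(f_layers, g_layers, k=2, f_dim=64, g_dim=128)
out = F(torch.randn(8, 32))   # -> logits of shape (8, 10)
```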
Alignment Increases with Scale and Performance
As AI models grow in scale and capability, their internal representations increasingly align. The paper demonstrates through experiments that models with higher capacity and stronger performance tend to exhibit more aligned representations. By analyzing 78 vision models with a mutual nearest-neighbor metric on the Places-365 dataset, the researchers found that models performing well on downstream tasks, such as those from the Visual Task Adaptation Benchmark (VTAB), are more aligned with each other than weaker models are. This trend highlights that as models become larger and better at handling more tasks, their internal structures converge, even across diverse architectures and training objectives. Essentially, the better a model performs, the more it “thinks” like other high-performing models, supporting the idea that increased scale leads to a more universal way of processing and representing data. This convergence could be a crucial step towards more generalized AI systems.
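A simplified sketch of a mutual nearest-neighbor alignment score, in the spirit of the metric used in the paper (the exact similarity function, value of k, and preprocessing here are assumptions):

```python
import numpy as np

def knn_indices(feats, k):
    """Indices of each point's k nearest neighbors under cosine similarity (self excluded)."""
    f = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    sims = f @ f.T
    np.fill_diagonal(sims, -np.inf)
    return np.argsort(-sims, axis=1)[:, :k]

def mutual_knn_alignment(feats_a, feats_b, k=10):
    """Average overlap of the k-NN sets that two representations induce over the same inputs."""
    nn_a, nn_b = knn_indices(feats_a, k), knn_indices(feats_b, k)
    return float(np.mean([len(set(a) & set(b)) / k for a, b in zip(nn_a, nn_b)]))

# Toy usage: features of the same 200 Places-365-style images from two models.
feats_model_1 = np.random.randn(200, 384)
feats_model_2 = np.random.randn(200, 768)
print(mutual_knn_alignment(feats_model_1, feats_model_2, k=10))
```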
Representations are Converging Across Modalities
The paper highlights how AI models are not only converging within single modalities but also across different ones, such as vision and language. As models become larger and more capable, their representations from diverse data types begin to align.
The experimental results in this figure highlight how language and vision models become more aligned as they improve in performance. Using the Wikipedia caption dataset, the researchers found a clear linear relationship: the better a language model performs, the more closely its representations align with those of high-performing vision models. Interestingly, CLIP models, which are trained with explicit language supervision, show even stronger alignment between the two modalities. However, this alignment drops when CLIP is fine-tuned for ImageNet classification, suggesting that task-specific training can affect cross-modal alignment. This experiment further supports the idea that as models grow more capable, their internal representations begin to converge, even across different types of data.
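The same kind of measurement extends across modalities: embed paired images and captions with their respective encoders and compare the neighborhood structures. A sketch, reusing `mutual_knn_alignment` from the earlier snippet, with placeholder encoders standing in for real vision and language models:

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder encoders; in practice these would be per-sample features extracted
# from a pretrained vision model and a language model (shapes are illustrative).
def embed_images(images):
    return rng.standard_normal((len(images), 768))

def embed_captions(captions):
    return rng.standard_normal((len(captions), 1024))

# Paired data: caption i describes image i (dummy placeholders here).
images = [f"image_{i}" for i in range(200)]
captions = [f"caption describing image {i}" for i in range(200)]

img_feats = embed_images(images)
txt_feats = embed_captions(captions)

# Cross-modal alignment: do the two modalities arrange the same underlying
# events similarly? (Uses mutual_knn_alignment defined in the earlier sketch.)
print(mutual_knn_alignment(img_feats, txt_feats, k=10))
```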
Alignment Improves Downstream Performance
Through experiments, the paper also shows that as models become more aligned, their performance on downstream tasks improves.
The above figure illustrates this by plotting the alignment of language models with a vision model (DINOv2) against their performance on common sense reasoning (Hellaswag) and math problem-solving (GSM8K). The results reveal a clear trend: models that are more aligned with vision models tend to perform better on these tasks. In particular, there’s a linear relationship between alignment and performance on common sense reasoning, and an "emergence" pattern for math tasks, where performance increases sharply once a certain level of alignment is reached. This suggests that alignment goes hand in hand with a model's ability to tackle complex real-world tasks.
The Reason for Representations Converging
Modern machine learning models are generally trained to minimize the empirical risk together with a regularizer:

$f^* = \arg\min_{f \in \mathcal{F}} \ \mathbb{E}_{x \sim \text{dataset}}\left[L(f, x)\right] + R(f)$

where the empirical-risk term is the green part referred to below, the function class $\mathcal{F}$ is the purple part, and the regularizer $R(f)$ is the red part.
The authors laid out how each colored component in this optimization process potentially plays a role in facilitating representational convergence.
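A minimal sketch of this objective in (assumed) PyTorch: the chosen architecture fixes the function class $\mathcal{F}$, the batch loss is the empirical risk, and weight decay stands in for the regularizer $R(f)$; all numbers are illustrative.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 10))  # f in F (capacity)
loss_fn = nn.CrossEntropyLoss()                                           # L(f, x)
opt = torch.optim.SGD(model.parameters(), lr=1e-2, weight_decay=1e-4)     # R(f) as weight decay

# One step of minimizing the empirical risk on a toy batch "sampled from the dataset".
x = torch.randn(64, 32)
y = torch.randint(0, 10, (64,))
opt.zero_grad()
loss = loss_fn(model(x), y)
loss.backward()
opt.step()
```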
Task Generality
The green part of the above formula corresponds to task generality. As data scales, models that optimize the empirical risk $\mathbb{E}_{x \sim \text{dataset}} [L(f, x)]$ also improve on the population risk $\mathbb{E}_{x \sim \text{reality}} [L(f, x)]$, and become better at capturing the statistical structure of the true data-generating process. Here is a visualization of how models trained on an increasing number of tasks are pressured to learn a representation that can solve all of them.
$\textbf{The Multitask Scaling Hypothesis}$ There are fewer representations that are competent for N tasks than there are for M < N tasks. As we train more general models that solve more tasks at once, we should expect fewer possible solutions.
The conclusion of this hypothesis is that scaling up the diversity of tasks during training naturally drives the model to develop more robust, generalized representations, ultimately enhancing its ability to perform well across a range of complex, real-world applications.
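A toy sketch of this multitask pressure (hypothetical tasks and dimensions): a single shared encoder must produce one representation that keeps every task head's loss low, and each added task further constrains which representations remain viable.

```python
import torch
import torch.nn as nn

encoder = nn.Sequential(nn.Linear(32, 64), nn.ReLU())                # shared representation
heads = nn.ModuleList([nn.Linear(64, 10) for _ in range(5)])         # N = 5 made-up tasks
loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(list(encoder.parameters()) + list(heads.parameters()), lr=1e-3)

# One optimization step on a toy batch shared by all tasks.
x = torch.randn(64, 32)
labels = [torch.randint(0, 10, (64,)) for _ in heads]                # one label set per task
z = encoder(x)
loss = sum(loss_fn(head(z), y) for head, y in zip(heads, labels))    # sum of all task risks
opt.zero_grad()
loss.backward()
opt.step()
```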
Model Capacity
The purple part of the formula corresponds to model capacity: the authors argue that larger AI models with greater computational resources are more likely to converge towards a universal representation of data, leading to better performance across a variety of tasks.
This figure illustrates the point by relating model size or complexity to performance; the underlying idea is that increased capacity allows for more intricate data processing, enabling the model to capture a broader range of patterns in the data, which in turn facilitates better generalization and alignment with the 'Platonic ideal' of a universal representation.
Simplicity Bias
The red part of the formula (the regularizer) corresponds to convergence via simplicity bias.
$\textbf{The Simplicity Bias Hypothesis}$ Deep networks are biased toward finding simple fits to the data, and the bigger the model, the stronger the bias. Therefore, as models get bigger, we should expect convergence to a smaller solution space.
The Simplicity Bias Hypothesis suggests that AI models, particularly when trained with standard optimization techniques like gradient descent, favour simpler, more predictable representations of the data over complex ones. The figure illustrates this by showing how simpler models can achieve competitive performance on tasks by focusing on the most predictable and generalizable features in the data. This simplicity bias drives models to converge towards more generalizable solutions, which can explain why different models often align in their representations even when trained on diverse tasks. The bias also helps models remain flexible and perform well across various tasks, despite the complexity of the underlying data.
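One well-known, concrete instance of a simplicity bias (a toy illustration, not the paper's experiment): in an overparameterized linear model, plain gradient descent initialized at zero converges to the minimum-norm solution that fits the data, i.e., the "simplest" interpolant, even with no explicit regularizer.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((20, 100))        # 20 samples, 100 parameters: overparameterized
y = rng.standard_normal(20)

# Plain gradient descent on the squared error, starting from zero.
w = np.zeros(100)
lr = 0.003
for _ in range(5000):
    w -= lr * X.T @ (X @ w - y)

w_min_norm = np.linalg.pinv(X) @ y        # closed-form minimum-norm interpolator
print(np.abs(w - w_min_norm).max())       # ~0: gradient descent picked the simplest fit
```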
What representation are we converging to?
We have discussed convergence at length, but what representation are we converging to? The authors argue that these models are increasingly approximating a representation of reality.
The authors support this hypothesis with a mathematical argument. The world is modeled as a sequence of discrete events sampled from a distribution $\mathbb{P}(Z)$, and each event can be observed in various ways, with each observation corresponding to a bijective function of the event. The authors then show that certain contrastive learners, trained only on such observations, recover a representation whose kernel encodes $\mathbb{P}(Z)$.
Here's the full proof procedures with my personal interpretations:
The co-occurrence probability between two observations $x_a$ and $x_b$ is defined based on their proximity within a time window:
$P_{coor}(x_a, x_b) \propto \sum_{(t, t^{\prime}):\ |t-t^{\prime}|\le T_{window}}\mathbb{P}(X_t = x_a, X_{t^{\prime}} = x_b)$
If the time difference between $x_a$ and $x_b$ is less than or equal to $T_{window}$, they are considered to co-occur (a $\textbf{positive pair}$), meaning they happen close together in time. In contrast, a $\textbf{negative pair}$ consists of observations sampled independently from the marginal distribution.
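A small sketch of how such pairs could be formed from a toy sequence of discrete observations (the sequence, vocabulary size, and window are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
T_window = 2
sequence = rng.integers(0, 5, size=1000)   # toy discrete observations x_t

def sample_positive():
    """Two observations that co-occur within the time window (|t - t'| <= T_window)."""
    t = rng.integers(0, len(sequence) - T_window)
    dt = rng.integers(1, T_window + 1)
    return sequence[t], sequence[t + dt]

def sample_negative():
    """Two observations drawn independently from the marginal distribution."""
    return rng.choice(sequence), rng.choice(sequence)

print(sample_positive(), sample_negative())
```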
The objective for the contrastive learner is to learn a way to represent these observations such that the similarity between the representations of the observations (measured by the dot-product of their representations) approximates the log odds ratio of them being a positive pair versus a negative pair.
$\begin{aligned}\langle f_X(x_a), f_X(x_b)\rangle &\approx \log\frac{\mathbb{P}(\text{pos}\mid x_a, x_b)}{\mathbb{P}(\text{neg}\mid x_a, x_b)} + \tilde{c}_X(x_a)\\ &= \log\frac{P_{coor}(x_a\mid x_b)}{P_{coor}(x_a)} + c_X(x_a)\\ &= K_{PMI}(x_a, x_b) + c_X(x_a)\end{aligned}$
where $K_{PMI}$ is the pointwise mutual information (PMI) kernel and $c_X(x_a)$ does not depend on $x_b$. The PMI kernel measures the mutual information between the two observations, essentially quantifying how much knowing one observation reduces uncertainty about the other. Note that $c_X(x_a)$ must in fact be a constant, since the left-hand side and $K_{PMI}$ are both symmetric in $x_a$ and $x_b$. Under mild conditions (that the world is smooth enough), a choice of $f_X$ can exactly represent $K_{PMI}$:
$\langle f_X(x_a), f_X(x_b)\rangle = K_{PMI}(x_a, x_b) + c_X$
Therefore, the results of the contrastive learners will converge to a representation $f_X$ whose kernel is $K_{PMI}$.
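To make the limiting object tangible, here is a sketch that estimates the PMI kernel directly from co-occurrence counts on a toy sequence (the same kind of setup as the sampling sketch above); an exact contrastive learner's representation would satisfy $\langle f_X(x_a), f_X(x_b)\rangle = K_{PMI}(x_a, x_b) + c_X$.

```python
import numpy as np
from collections import Counter

rng = np.random.default_rng(0)
T_window = 2
sequence = rng.integers(0, 5, size=5000)          # toy observations over 5 discrete events

# Count co-occurrences within the time window (unordered, so count both directions).
counts = Counter()
for t in range(len(sequence)):
    for dt in range(1, T_window + 1):
        if t + dt < len(sequence):
            a, b = sequence[t], sequence[t + dt]
            counts[(a, b)] += 1
            counts[(b, a)] += 1

joint = np.zeros((5, 5))
for (a, b), c in counts.items():
    joint[a, b] = c
joint /= joint.sum()                               # P_coor(x_a, x_b)
marginal = joint.sum(axis=1)                       # P_coor(x_a)

K_pmi = np.log(joint / np.outer(marginal, marginal))   # the PMI kernel K_PMI(x_a, x_b)
print(np.round(K_pmi, 3))
```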
Thus we have convergence to a representation of the statistics of X, but this does not yet show convergence to the underlying reality Z. So what about Z?
Recall that our idealized world uses bijective observation functions, which, over discrete random variables, preserve probabilities; hence:
$\begin{aligned}P_{coor}(x_a, x_b) &= P_{coor}(z_a, z_b)\\ K_{PMI}(x_a, x_b) &= K_{PMI}(z_a, z_b)\end{aligned}$
All these arguments hold not just for X but also for Y (or any other bijective, discrete modality), implying:
$\begin{aligned}K_{PMI}(z_a, z_b) &= \langle f_X(x_a), f_X(x_b) \rangle - c_{X}\\ &= \langle f_Y(y_a), f_Y(y_b) \rangle - c_{Y}\end{aligned}$
Therefore, for any modality in our idealized world, representations converge to the same kernel, which captures certain pairwise statistics of $\mathbb{P}(Z)$.
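A tiny numeric check of the bijection argument (toy distribution, assumed setup): relabeling the events of $Z$ through a permutation, as a bijective observation map would, permutes the rows and columns of the PMI kernel but leaves its values intact, so any bijective modality recovers the same kernel over $\mathbb{P}(Z)$.

```python
import numpy as np

def pmi(joint):
    """PMI kernel of a symmetric joint co-occurrence distribution."""
    marginal = joint.sum(axis=1)
    return np.log(joint / np.outer(marginal, marginal))

rng = np.random.default_rng(0)
joint_z = rng.random((5, 5))
joint_z = joint_z + joint_z.T
joint_z /= joint_z.sum()                        # P_coor over the events Z

perm = rng.permutation(5)                       # a bijection Z -> X (pure relabeling)
joint_x = joint_z[np.ix_(perm, perm)]           # co-occurrence as seen in modality X

K_z, K_x = pmi(joint_z), pmi(joint_x)
print(np.allclose(K_x, K_z[np.ix_(perm, perm)]))   # True: same kernel, relabeled indices
```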
Implications of the Convergence for the Real World
- First, scaling alone is not sufficient. Different methods scale at varying levels of efficiency, and to succeed, they must meet some core requirements, such as being consistent estimators or accurately modeling the pairwise statistics of the underlying reality $P(Z)$.
- Second, training data can be shared across modalities. As models converge across modalities, it becomes easier to train them using data from different domains, which enhances versatility.
- Third, cross-modal adaptability becomes more seamless. When models share a common modality-agnostic representation, they can readily adapt to new modalities, making it easier to translate knowledge from one domain to another.
- Lastly, scaling may reduce hallucinations and bias. As models converge towards a more accurate representation of reality, increasing their scale could help diminish hallucinations and biases by aligning them more closely with the real world.
Take-aways
I feel this paper could offer hints and guidance on the growth and research directions of many ventures. Here are some examples:
- Unified Architecture: Focus on developing or adapting architectures that demonstrate strong performance across multiple tasks and modalities, reducing the need for specialized solutions.
- Data Efficiency: By leveraging models that converge towards universal representations, we can enhance data efficiency, requiring less labeled data from specific domains to achieve high performance.
- Integrated Data Solutions: Create solutions that integrate different types of data inputs to provide richer insights, essential for complex decision-making processes in business environments.
- Infrastructure Development: Invest in computational infrastructure that supports the training and deployment of large-scale models.
- Long-term AI Strategy: Formulate a long-term strategy focusing on creating and refining AI models that align with the convergence trends identified in the research.