Exploring A Simple Attention Mechanism for Multi-omic Microbiomics
Overview
The following serves as an introduction to attention models and an exploration of their utility for multi-omics, through an example analysis of a Type II Diabetes (T2D) vs. Control experiment. I consider a main (dominant) metagenomic dataset and a subordinate blood serum dataset, using a modified version of this model. The model used here is pictured below.
A simple attention model with three key steps: (1) a linear layer mapping each feature in each dataset to modules, where each module is encoded as a two-dimensional vector; (2) a cosine-similarity attention mechanism describing the similarity between modules; and (3) a fully connected prediction layer.
Understanding the Model
Fitting the Model
Each module is a weighted linear combination of the input data; intuitively, each module is just a two-dimensional vector (pictured below). Suppose we encode $M_x$ and $M_y$ modules for datasets $X \in \mathbb{R}^{N \times D_x}$ and $Y \in \mathbb{R}^{N \times D_y}$, respectively. Let $\mathbf{x} \in \mathbb{R}^{D_x}$ be one sample from one of the datasets (e.g., metagenomics). Its modules $\mathbf{M}(\mathbf{x}) \in \mathbb{R}^{M_x \times 2}$ are two-dimensional, normalized, weighted linear combinations of the observed data:
$$ \mathbf{M}(\mathbf{x}) = \begin{bmatrix}\mathbf{W^{(0)}}\mathbf{x} & \mathbf{W^{(1)}}\mathbf{x}\end{bmatrix}, \qquad \mathbf{M}(\mathbf{x})\_{i,\cdot} \leftarrow \frac{\mathbf{M}(\mathbf{x})\_{i,\cdot}} {\left\lVert \mathbf{M}(\mathbf{x})\_{i,\cdot} \right\rVert_2} $$
where $\mathbf{W^{(0)}}, \mathbf{W^{(1)}} \in \mathbb{R}^{M_x \times D_x}$ are learned weight matrices for the two-dimensional module vectors and $\lVert \mathbf{z} \rVert_2$ is the $\ell_2$-norm.
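To make the encoding step concrete, here is a minimal NumPy sketch for a single sample. It assumes random toy weights standing in for the learned $\mathbf{W^{(0)}}, \mathbf{W^{(1)}}$; the function name `encode_modules` is hypothetical and not part of the original code.

```python
import numpy as np


def encode_modules(x, W0, W1, eps=1e-12):
    """Encode one sample x, shape (D,), as M unit-length 2-D module vectors.

    W0, W1: weight matrices of shape (M, D), one per module coordinate.
    Returns an (M, 2) array whose rows have l2-norm 1.
    """
    # Row i holds module i's raw coordinates (W0 @ x)[i] and (W1 @ x)[i].
    raw = np.stack([W0 @ x, W1 @ x], axis=1)
    # Divide each row by its length (clamped at eps, as F.normalize does),
    # so only the module's direction is kept.
    norms = np.maximum(np.linalg.norm(raw, axis=1, keepdims=True), eps)
    return raw / norms


rng = np.random.default_rng(0)
D_x, M_x = 6, 4                     # toy sizes: 6 features, 4 modules
x = rng.random(D_x)                 # one toy sample
W0 = rng.normal(size=(M_x, D_x))    # random stand-ins for the learned weights
W1 = rng.normal(size=(M_x, D_x))

Mx = encode_modules(x, W0, W1)
print(Mx.shape)                     # (4, 2)
print(np.linalg.norm(Mx, axis=1))   # each module vector has length 1 (up to rounding)
```

In the batched PyTorch model, the same computation is done by the bias-free `nn.Linear` encoders followed by `torch.stack` and `F.normalize`.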
Attention is calculated from the cosine similarity between a metagenomic module vector $\mathbf{M}(\mathbf{x})_{i,\cdot}$ and a blood serum module vector $\mathbf{M}(\mathbf{y})_{j,\cdot}$ (pictured below), where $\mathbf{y} \in \mathbb{R}^{D_y}$ is the same sample's blood serum profile, encoded in the same way with its own weight matrices. Since the vectors are normalized, the cosine similarity is just their dot product:
$$ \cos(\theta\_{ij}) = \mathbf{M}(\mathbf{x})\_{i,\cdot} \cdot \mathbf{M}(\mathbf{y})\_{j,\cdot}^\top $$
Example Modules and Cosine Similarity: Four metagenomic modules and one blood serum module are shown. Each module can be represented as a vector; attention is based on the cosine similarities between the metagenomic and blood serum modules. The angle $\theta$ and cosine similarity between modules B and E are shown.
Modules with larger cosine similarities are given higher attention. The attention matrix $\mathbf{A} \in \mathbb{R}^{M_x \times M_y}$ is the matrix of these cosine similarities with a softmax applied to each column. Intuitively, one column of the attention matrix gives the proportion of attention that one blood serum module assigns to each of the metagenomic modules.
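The cosine-similarity and column-softmax steps described above can be checked numerically. The sketch below is a toy example, not the fitted model: the modules are placed at arbitrary angles of my choosing, and the helper names `column_softmax` and `unit` are hypothetical. In the batched PyTorch model, the same column softmax is `F.softmax(raw_attn, dim=1)`.

```python
import numpy as np


def column_softmax(S):
    """Softmax down each column of S, so every column sums to 1."""
    # Subtracting each column's max improves numerical stability without changing the result.
    Z = np.exp(S - S.max(axis=0, keepdims=True))
    return Z / Z.sum(axis=0, keepdims=True)


def unit(angle):
    """2-D unit vector at `angle` radians."""
    return np.array([np.cos(angle), np.sin(angle)])


# Toy modules defined by their angles: four metagenomic modules (rows of Mx)
# and one blood serum module (row of My).
Mx = np.stack([unit(t) for t in (0.0, np.pi / 4, np.pi / 2, np.pi)])
My = np.stack([unit(np.pi / 3)])

S = Mx @ My.T            # (M_x, M_y); for unit vectors, dot product = cosine
A = column_softmax(S)    # column j: attention serum module j gives each metagenomic module
print(np.round(S[:, 0], 3))   # cosines of the angle gaps pi/3, pi/12, pi/6, 2*pi/3
print(A.sum(axis=0))          # each column sums to 1
```

One consequence worth noting: cosine similarities lie in $[-1, 1]$ and the model's code applies no temperature to them, so any two attention weights in the same column differ by at most a factor of $e^2 \approx 7.4$, which keeps the attention fairly soft.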
$$ \mathbf{A} = \text{softmax}\_{\text{col}}\left(\mathbf{M}(\mathbf{x}) \cdot \mathbf{M}(\mathbf{y})^\top\right) $$
The attention-weighted modules $\mathbf{T} \in \mathbb{R}^{M_y \times 2}$ have one row per blood serum module; each row is an attention-weighted average of all the metagenomic modules (pictured below):
$$ \mathbf{T} = \mathbf{A}^\top \mathbf{M}(\mathbf{x}) $$
Example Cont. Attention Matrix & Weighted Modules: The attention matrix $\mathbf{A}$ (right-most gray-scale matrix) holds the attention weights that blood serum module E assigns to metagenomic modules $A$-$D$. The metagenomic module vectors $\mathbf{M}(\mathbf{x})$ from the previous plot are shown again in the blue-scale matrix. Through the relationship $\mathbf{T} = \mathbf{A}^\top \mathbf{M}(\mathbf{x})$, the attention-weighted module $F$ (purple-scale) is obtained. Module vectors $A$-$F$ are depicted.
Module Feature Importances
Each module is built from one two-dimensional vector per feature: before normalization, module $a$ is the sum $\sum_i x_i\, \mathbf{w}(i, a)$ of the feature vectors $x_i\, \mathbf{w}(i, a)$. However, modules are normalized, so only their direction matters, and feature vectors with larger magnitudes influence that direction more. Feature importance measures the strength of a feature along the module’s final direction. Intuitively, it represents how strongly a feature ‘voted’ for the module’s final direction (or for the opposite direction, when the importance is negative). Formally, it is the feature vector projected onto the module direction. For feature $i$ and module $a$, the importance $I(i, a)$ is:
$$ I(i, a) = \left\lVert x\_{i}\, \mathbf{w}(i, a) \right\rVert_2 \cos(\theta) $$
where $\mathbf{w}(i,a) = \left(W^{(0)}_{a,i}, W^{(1)}_{a,i}\right)$ is the two-dimensional weight vector connecting feature $i$ to module $a$, and $\theta$ is the angle between the feature vector $x_i\, \mathbf{w}(i,a)$ and the module vector. (Equivalently, $I(i, a) = \left(x_{i}\, \mathbf{w}(i,a)\right) \cdot \mathbf{M}(\mathbf{x})_{a,\cdot}$.)
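A NumPy sketch that computes $I(i, a)$ both ways for one module, using random toy data and weights (not the fitted MetaCardis model). It assumes the module is encoded exactly as in the first equation, so $\mathbf{w}(i,a)$ is column $i$ of row $a$ of the two weight matrices.

```python
import numpy as np

rng = np.random.default_rng(1)
D_x, M_x = 5, 3                                  # toy sizes
x = rng.random(D_x)                              # one toy sample
W0 = rng.normal(size=(M_x, D_x))                 # stand-ins for the learned W^(0), W^(1)
W1 = rng.normal(size=(M_x, D_x))

a = 0                                            # index of the module to explain
w = np.stack([W0[a], W1[a]], axis=1)             # row i is w(i, a); shape (D_x, 2)
feat_vecs = x[:, None] * w                       # row i is x_i * w(i, a)
raw_module = feat_vecs.sum(axis=0)               # module a before normalization
u = raw_module / np.linalg.norm(raw_module)      # unit module vector M(x)_{a,.}

# Form 1: length of each feature vector times cos(angle to the module)
norms = np.linalg.norm(feat_vecs, axis=1)
cos_theta = (feat_vecs @ u) / norms
I_angle = norms * cos_theta

# Form 2: projection (dot product) of each feature vector onto the unit module
I_proj = feat_vecs @ u

print(np.round(I_proj, 2))
print(np.isclose(I_proj.sum(), np.linalg.norm(raw_module)))  # True
```

The last line illustrates a property that follows from the projection form: because the feature vectors sum to the unnormalized module, the importances within a module add up to that module's length before normalization.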
Three example feature importances for metagenomic module $D$. (Left Plot) The importance of feature $i$ (green vector) for module $D$ (blue vector) is $I(i, D)$ (black vector). Feature vector $i$ has a large magnitude and points in a similar direction to module $D$, resulting in a large positive importance, $I(i, D)=2.53$. (Middle Plot) Feature vector $j$ has a large magnitude but is almost orthogonal to $D$, resulting in a small importance, $I(j,D)=-0.18$. (Right Plot) Feature vector $k$ has a moderate magnitude and points roughly in the opposite direction of $D$, resulting in a moderate negative importance, $I(k,D)=-0.5$.
Type II Diabetes Dataset Results
We’ll fit the attention model to this MetaCardis dataset (see Zenodo and GitHub for data/analysis), which has 154 control samples and 85 T2D samples. The metagenomic data measured the absolute abundances of 59 genera using shotgun metagenomic sequencing and flow cytometry. The blood serum data measured 29 metabolites using Ultra-Performance Liquid Chromatography–Mass Spectrometry (UPLC–MS).
Four modules were used for each dataset (metagenomic modules A-D and blood serum modules E-H). The fitted model had an AUC of 0.97 and an accuracy of 0.93 on the held-out test set.
Modules and Attention Results
The module encodings and cosine-similarity attention matrices for the trained model are shown below. All values shown are means across the training set (the encodings and attention values differ from sample to sample).
Feature Importances
The importances of each blood serum metabolite and each metagenomic genus for each module are shown below. (Note: only the three metabolites with absolute importance $\geq 0.1$ in at least one blood serum module and the 28 genera with absolute importance $\geq 1$ in at least one metagenomic module are shown.)
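The display filter in the note is an "at least one module" rule. A minimal sketch, assuming the importances are stored as a features × modules array; the function name, feature names, and numbers below are hypothetical toy values, not results from this analysis.

```python
import numpy as np


def features_to_show(importances, names, threshold):
    """Names of features whose |importance| >= threshold in at least one module.

    importances: array of shape (n_features, n_modules).
    """
    keep = (np.abs(importances) >= threshold).any(axis=1)
    return [name for name, k in zip(names, keep) if k]


# Toy importance table: 4 hypothetical metabolites x 2 modules (made-up numbers).
imp = np.array([[0.30, -0.02],
                [0.05, 0.04],
                [-0.01, -0.12],
                [0.09, 0.00]])
names = ["met_1", "met_2", "met_3", "met_4"]
shown = features_to_show(imp, names, threshold=0.1)
print(shown)  # ['met_1', 'met_3']
```

Taking the absolute value before thresholding keeps features with strong negative importances (such as `met_3` here), which matches the note's use of absolute importance.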
Interpretation
The metagenomic modules A-D and their features are interpreted below for each blood serum module E-H:
Module H: In T2D, blood glucose tends to be elevated while glutamine is depleted. Module H captures this pattern: when the glucose signal is strong (T2D), it attends to metagenomic modules A/D, and when it is weak (Control), it attends to modules B/C. The Control-associated modules B/C identify genera with beneficial gut-health associations, including the butyrate producers Faecalibacterium and Coprococcus and the genera Akkermansia and Bacteroides, which are linked to improved insulin secretion/sensitivity. In T2D, modules A/D give the beneficial genera Roseburia and Bifidobacterium negative importances, while genera with positive importances include Bilophila (promotes intestinal barrier dysfunction, leading to glucose dysmetabolism), Collinsella (decreases liver glycogenesis and increases triglyceride synthesis), Ruminococcus (promotes inflammatory cytokine production, leading to insulin resistance), and Streptococcus (related to inflammation).
Module E: This module attends to metagenomic modules A/D in both Control and T2D and aligns with glucose and lactic acid. It identifies the genera that make modules A/D point in somewhat different directions in Control vs. T2D. In Control, for instance, modules A/D place more importance on some known beneficial genera, including Coprococcus and Odoribacter.
Module F: Lactic acid can be elevated in individuals with T2D due to, for instance, decreased blood flow in adipose tissue (common in obesity). This module attends to metagenomic modules A/D when lactic acid is sufficiently elevated (i.e., in T2D), but this signal disappears when glutamine overrides it (i.e., in Control).
Module G: This module is not particularly well aligned with any of the blood serum metabolites.
Discussion
The following are some loosely associated concluding points:
Reproducibility: I am skeptical of the reproducibility of these results. Other methods (e.g., the sGCCA-based MintTea) use data subsampling to derive consensus modules; dropout in the linear module encoders could perhaps serve a similar purpose. Training on multiple datasets across studies could also help.
Some Surprising Results: Some genera do not match expectations. For example, Eubacterium, although a beneficial butyrate producer, unexpectedly has negative importance in Control modules B/C and positive importance in T2D module D. Three hypotheses could explain this: (1) This analysis used absolute abundances, while many studies consider relative abundances; reanalyzing the data using relative abundances (such as CLR-transformed abundances) might produce results more concordant with prior studies. (2) The model's expressiveness was limited by using only four modules per dataset and only two-dimensional module vectors. (3) This may simply be a reproducibility issue, i.e., a study-specific result.
Improving Interpretation: While not performed here, an enrichment analysis of each module against annotated microbe sets could help with interpretation. However, relying on known sets defeats the purpose of identifying novel ones.
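Hypothesis (1) under Some Surprising Results could be tested by re-fitting the model on CLR-transformed abundances. A minimal NumPy sketch of the centered log-ratio transform, assuming samples are rows and genera are columns; the pseudocount used to avoid $\log(0)$ is a common device, but its default value here is an assumption, and the abundances are toy numbers.

```python
import numpy as np


def clr(abundances, pseudocount=1e-6):
    """Centered log-ratio (CLR) transform of each sample (row).

    The pseudocount avoids log(0) for absent genera; its value is an assumption.
    """
    logx = np.log(abundances + pseudocount)
    # Subtracting the row mean of the logs = dividing each sample by its geometric mean.
    return logx - logx.mean(axis=1, keepdims=True)


# Two toy samples of 3 genera; the second is the first scaled by 10
# (e.g., same composition but a higher total microbial load).
X_abs = np.array([[100.0, 50.0, 10.0],
                  [1000.0, 500.0, 100.0]])
Z = clr(X_abs, pseudocount=0.0)
print(np.allclose(Z[0], Z[1]))  # True: CLR ignores total abundance
print(Z.sum(axis=1))            # each row sums to ~0
```

Because CLR is invariant to rescaling a sample, the absolute and relative abundances of the same sample map to the same CLR profile, which is what would make such a re-analysis comparable to relative-abundance studies.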
Code
The full code is provided here on GitHub, but the main model code is:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiOmicsModuleAtt(nn.Module):
    def __init__(self, feature_dim_a, feature_dim_b, num_modules_a,
                 num_modules_b, hidden_dim, out_dim=1):
        """
        Args:
            feature_dim_a: Da (# features dataset xa)
            feature_dim_b: Db (# features dataset xb)
            num_modules_a: Number modules desired for dataset xa
            num_modules_b: Number modules desired for dataset xb
            hidden_dim: Final layer hidden dim
            out_dim: Final output dim (1 for binary classification)
        """
        super().__init__()
        self.num_modules_a = num_modules_a
        self.num_modules_b = num_modules_b
        # One bias-free linear map per module coordinate (W^(0) -> "x", W^(1) -> "y")
        self.encoder_ax = nn.Linear(feature_dim_a, num_modules_a, bias=False)
        self.encoder_ay = nn.Linear(feature_dim_a, num_modules_a, bias=False)
        self.encoder_bx = nn.Linear(feature_dim_b, num_modules_b, bias=False)
        self.encoder_by = nn.Linear(feature_dim_b, num_modules_b, bias=False)
        # Prediction head on the flattened (num_modules_b, 2) attention-weighted modules
        self.fc1 = nn.Linear(num_modules_b * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, xa, xb):
        """
        Args:
            xa: Dominant dataset matrix
                (N, Da) for N (# samples), Da (# features)
            xb: Subordinate data set (N, Db)
        """
        # Encode (x, y) vector coordinates for input datasets xa, xb
        enc_ax = self.encoder_ax(xa)
        enc_ay = self.encoder_ay(xa)
        enc_bx = self.encoder_bx(xb)
        enc_by = self.encoder_by(xb)
        # Stack (x, y) coordinates into module vectors: (N, Ma, 2) and (N, Mb, 2)
        enc_a = torch.stack([enc_ax, enc_ay], dim=2)
        enc_b = torch.stack([enc_bx, enc_by], dim=2)
        # L2 normalization of each module vector
        enc_a = F.normalize(enc_a, p=2, dim=2)
        enc_b = F.normalize(enc_b, p=2, dim=2)
        # Raw attention matrix (cosine similarities): (N, Ma, Mb)
        raw_attn = enc_a @ enc_b.transpose(1, 2)
        # Column softmax: each xb module's attention over the xa modules sums to 1
        attn_a = F.softmax(raw_attn, dim=1)
        # Attention-weighted modules T = A^T M(x): (N, Mb, 2)
        att_enc_a = attn_a.transpose(1, 2) @ enc_a
        # Flatten attention-weighted modules and predict
        att_enc_a_flat = att_enc_a.reshape(-1, self.num_modules_b * 2)
        out = self.fc1(att_enc_a_flat)
        out = F.relu(out)
        out = self.fc2(out)
        return out.squeeze(1)  # shape (N,) raw scores when out_dim=1
```