°º¤ø,¸¸,ø¤ Exploring Biological Data ¸,ø¤º°`°º¤ø,¸

Exploring A Simple Attention Mechanism for Multi-omic Microbiomics

Overview


The following serves as an introduction to attention models and an exploration of the utility of these models for multi-omics through an example analysis of a Type II Diabetes (T2D) vs. Control experiment. I consider a main metagenomic dataset and a subordinate blood serum dataset using a modified version of this model. The model used here is pictured below.

A simple attention model with three key steps: (1) a linear layer mapping the features of each dataset to modules, where each module is encoded as a two-dimensional vector; (2) a cosine similarity attention mechanism describing the similarity between modules; and (3) a fully connected prediction layer.

Understanding the Model


Fitting the Model


Each module is a weighted linear combination of the input data. Intuitively, each module is just a two-dimensional vector (pictured below). Suppose we encode $M_x$ and $M_y$ modules for datasets $X \in \mathbb{R}^{N \times D_x}$ and $Y \in \mathbb{R}^{N \times D_y}$, respectively, and let $\mathbf{x} \in \mathbb{R}^{D_x}$ be one sample from one of the datasets (e.g., metagenomics). The modules $\mathbf{M}(\mathbf{x}) \in \mathbb{R}^{M_x \times 2}$ are two-dimensional, normalized, weighted linear combinations of the observed data:

$$ \mathbf{M}(\mathbf{x}) = \begin{bmatrix}\mathbf{W^{(0)}}\mathbf{x} & \mathbf{W^{(1)}}\mathbf{x}\end{bmatrix}, \qquad \mathbf{M}(\mathbf{x})_{i,\cdot} \leftarrow \frac{\mathbf{M}(\mathbf{x})_{i,\cdot}}{\left\lVert \mathbf{M}(\mathbf{x})_{i,\cdot} \right\rVert_2} $$

where $\mathbf{W^{(0)}}, \mathbf{W^{(1)}} \in \mathbb{R}^{M_x \times D_x}$ are learned weight matrices, one for each coordinate of the two-dimensional module vectors, and $\lVert \mathbf{z} \rVert_2$ is the $\ell_2$ norm.
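
As a minimal sketch of this encoding step, the plain-Python snippet below builds the module vectors for one sample and normalizes them. The weights, the sample, and the helper name `encode_modules` are made up for illustration; the actual model learns the weights with PyTorch linear layers, and this sketch assumes no module has zero length.

```python
import math

def encode_modules(x, W0, W1):
    """Return one unit-length 2-D vector per module for a single sample x.

    W0 and W1 are lists of rows (one row per module, one column per feature)
    giving the first and second coordinate of each module, respectively.
    """
    modules = []
    for row0, row1 in zip(W0, W1):
        # Each coordinate is a weighted linear combination of the features.
        c0 = sum(w * xi for w, xi in zip(row0, x))
        c1 = sum(w * xi for w, xi in zip(row1, x))
        # L2-normalize so that only the direction of the module remains.
        norm = math.hypot(c0, c1)
        modules.append((c0 / norm, c1 / norm))
    return modules

# Hypothetical sample with D_x = 3 features and M_x = 2 modules.
x = [1.0, 2.0, 0.5]
W0 = [[0.5, 0.0, 1.0], [-1.0, 0.5, 0.0]]
W1 = [[0.0, 1.0, -1.0], [0.5, 0.5, 2.0]]
modules = encode_modules(x, W0, W1)
```

Every entry of `modules` has length one, so the second toy module, whose unnormalized coordinates are (0, 2.5), ends up pointing along the second axis.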

Attention is calculated from the cosine similarity between a module vector from each dataset, $\mathbf{M}(\mathbf{x})_{i,\cdot}$ and $\mathbf{M}(\mathbf{y})_{j,\cdot}$ (pictured below). Since the vectors are normalized, the cosine similarity is simply their dot product:

$$ \cos(\theta) = \mathbf{M}(\mathbf{x})_{i,\cdot} \cdot \mathbf{M}(\mathbf{y})_{j,\cdot}^\top $$

Example Modules and Cosine Similarity: Four metagenomics modules and one blood serum module are shown. Each module can be represented as a vector; attention is based on the cosine similarities between the metagenomic and blood serum modules. The angle $\theta$ and cosine similarity between modules B and E are shown.
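
To make the dot-product shortcut concrete, here is a small sketch with two unit vectors placed at angles I picked arbitrarily (30 and 75 degrees). The helpers `unit` and `dot` are illustrative names, not part of the model code.

```python
import math

def unit(angle_degrees):
    """Unit vector at the given angle from the positive first axis."""
    t = math.radians(angle_degrees)
    return (math.cos(t), math.sin(t))

def dot(u, v):
    """Dot product of two 2-D vectors."""
    return u[0] * v[0] + u[1] * v[1]

# Hypothetical module directions: one metagenomic, one blood serum.
m_x = unit(30.0)
m_y = unit(75.0)

# For unit vectors the dot product is the cosine of the angle between them.
similarity = dot(m_x, m_y)
theta = math.degrees(math.acos(similarity))
```

Here `theta` recovers the 45 degree gap between the two directions, and `similarity` equals the cosine of that angle.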

Modules with larger cosine similarities are given higher attention. The attention matrix $\mathbf{A} \in \mathbb{R}^{M_x \times M_y}$ is just the cosine similarities but with a column softmax applied. Intuitively, one column of the attention matrix represents the proportion of attention one blood serum module assigns to all the metagenomic modules.

$$ \mathbf{A} = \text{softmax}_{\text{col}}\left(\mathbf{M}(\mathbf{x}) \cdot \mathbf{M}(\mathbf{y})^\top\right) $$
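
The column softmax can be sketched in a few lines of plain Python. The similarity values below are invented for illustration, and `column_softmax` is a hypothetical helper rather than the function used in the model.

```python
import math

def column_softmax(S):
    """Softmax down each column, so every column sums to 1."""
    n_rows, n_cols = len(S), len(S[0])
    A = [[0.0] * n_cols for _ in range(n_rows)]
    for j in range(n_cols):
        exps = [math.exp(S[i][j]) for i in range(n_rows)]
        total = sum(exps)
        for i in range(n_rows):
            A[i][j] = exps[i] / total
    return A

# Hypothetical cosine similarities: 4 metagenomic modules (rows) by
# 2 blood serum modules (columns). All values lie in [-1, 1].
S = [[0.9, -0.2],
     [0.1, 0.8],
     [-0.5, 0.3],
     [-1.0, -0.7]]
A = column_softmax(S)
column_sums = [sum(A[i][j] for i in range(4)) for j in range(2)]
```

Because cosine similarities are bounded in [-1, 1] and no temperature is applied, the largest and smallest weights within a column can differ by at most a factor of $e^2 \approx 7.4$, so the attention is never extremely peaked.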

The attention-weighted modules $\mathbf{T} \in \mathbb{R}^{M_y \times 2}$ contain, for each blood serum module, an attention-weighted average of all the metagenomic modules (pictured below):

$$ \mathbf{T} = \mathbf{A}^\top \mathbf{M}(\mathbf{x}) $$

Example Cont. Attention Matrix & Weighted Modules: The attention matrix $\mathbf{A}$ (right-most gray-scale matrix) holds the attention weights that blood serum module $E$ assigns to metagenomic modules $A$-$D$. The metagenomic module vectors $\mathbf{M}(\mathbf{x})$ from the previous plot are shown again in the blue-scale matrix. Through the relationship $\mathbf{T} = \mathbf{A}^\top \mathbf{M}(\mathbf{x})$, the attention-weighted module $F$ (purple-scale) is obtained. Module vectors $A$-$F$ are depicted.
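
The relationship $\mathbf{T} = \mathbf{A}^\top \mathbf{M}(\mathbf{x})$ amounts to a weighted average of module vectors. A small sketch, with module directions and attention weights that I chose by hand (they are not the values from the figure):

```python
import math

def attention_weighted_modules(A, M):
    """Compute T = A^T M for A of shape (M_x, M_y) and M of shape (M_x, 2)."""
    n_x, n_y = len(A), len(A[0])
    T = []
    for j in range(n_y):
        # Row j of T averages the M_x module vectors using column j of A.
        t0 = sum(A[i][j] * M[i][0] for i in range(n_x))
        t1 = sum(A[i][j] * M[i][1] for i in range(n_x))
        T.append((t0, t1))
    return T

# Hypothetical unit module vectors for 4 metagenomic modules.
M = [(1.0, 0.0), (0.0, 1.0), (-1.0, 0.0), (0.0, -1.0)]
# Hypothetical attention weights of one blood serum module (one column).
A = [[0.5], [0.3], [0.1], [0.1]]
T = attention_weighted_modules(A, M)
t_norm = math.hypot(T[0][0], T[0][1])
```

The single row of `T` is (0.4, 0.2). Since it is a convex combination of unit vectors, its length is at most one, and it shrinks toward zero when attention is spread over modules pointing in opposing directions.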

Module Feature Importances


Each module is a weighted, linear combination of the learned feature vectors. However, modules are normalized: only their direction matters. Yet feature vectors with larger magnitudes influence module direction more. Feature importance measures the strength of the feature along the module’s final direction. Intuitively, feature importance represents how strongly a feature ‘voted’ in the module’s final direction (or the opposite direction for negative importance). Formally, it is the feature’s magnitude projected onto the module direction. For feature $i$, module $a$, importance $I(i, a)$ is:

$$ I(i, a) = \left\lVert x_{i} \times \mathbf{w}(i, a) \right\rVert_2 \times \cos(\theta) $$

where $\mathbf{w}(i,a)$ is the vector of weights connecting feature $i$ to module $a$ and $\theta$ is the angle between the feature and module vectors. Equivalently, $I(i, a) = \left(x_{i} \times \mathbf{w}(i,a)\right) \cdot \mathbf{M}(\mathbf{x})_{a,\cdot}$.
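
A rough sketch of the importance calculation for a single module follows. The sample and weights are toy values, `feature_importances` is a hypothetical helper, and the sketch assumes the encoder has no bias term (as in the model code).

```python
import math

def feature_importances(x, w0, w1):
    """Importance of each feature for one module.

    w0 and w1 hold the weights linking every feature to the module's
    first and second coordinate. Returns the importances and the length
    of the unnormalized module vector.
    """
    # Per-feature contribution vectors x_i * w(i, a).
    contributions = [(xi * a, xi * b) for xi, a, b in zip(x, w0, w1)]
    # The unnormalized module is the sum of the contributions.
    s0 = sum(c[0] for c in contributions)
    s1 = sum(c[1] for c in contributions)
    norm = math.hypot(s0, s1)
    direction = (s0 / norm, s1 / norm)
    # Project each contribution onto the unit module direction.
    scores = [c[0] * direction[0] + c[1] * direction[1] for c in contributions]
    return scores, norm

# Hypothetical sample with three features and their weights for one module.
x = [1.0, 2.0, 0.5]
w0 = [2.0, 0.0, -2.0]
w1 = [1.0, 1.0, -2.0]
importances, module_norm = feature_importances(x, w0, w1)
```

One property that falls out of the projection: without a bias term, the importances for a module sum to the length of the unnormalized module vector, so they can be read as each feature's share of that length. In this toy case the third feature points away from the module and receives a negative importance.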

Three example feature importances for metagenomic module $D$. (Left Plot) The importance of feature $i$ (green vector) for module $D$ (blue vector) is $I(i, D)$ (black vector). Feature vector $i$ has a large magnitude and points in a similar direction to module $D$, resulting in the large positive importance $I(i, D)=2.53$. (Middle Plot) Feature vector $j$ has a large magnitude but is almost orthogonal to $D$, resulting in a small importance of $I(j,D)=-0.18$. (Right Plot) Feature vector $k$ has a moderate magnitude and points roughly in the opposite direction of $D$, resulting in a moderate negative importance of $I(k,D)=-0.5$.

Type II Diabetes Dataset Results


We’ll fit the attention model to this Metacardis dataset (see Zenodo and GitHub for data/analysis), which contains 154 control samples and 85 T2D samples. The metagenomic data measure the absolute abundances of 59 genera, obtained with shotgun metagenomic sequencing and flow cytometry. The blood serum data measure 29 metabolites with Ultra-Performance Liquid Chromatography–Mass Spectrometry (UPLC–MS).

Four modules were used for each of the metagenomic and blood serum datasets. The fitted model had an AUC of 0.97 and an accuracy of 0.93 on the held-out test set.

Modules and Attention Results


The module encodings and cosine-similarity attention matrices for the trained model are shown below. All values shown are the mean values across the training set (encodings/attention values differ across samples).

Feature Importances


The importances for each module and for each blood serum metabolite / metagenomic genus are shown below. (Note: only the three metabolites with absolute importance $\geq 0.1$ in at least one blood serum module and the 28 genera with absolute importance $\geq 1$ in at least one metagenomic module are shown.)
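
The filtering rule in the note can be sketched as below. The feature names and importance values are placeholders I invented, not results from the fitted model, and `filter_features` is a hypothetical helper.

```python
def filter_features(importances, threshold):
    """Keep features whose absolute importance reaches the threshold
    in at least one module."""
    return {
        name: values
        for name, values in importances.items()
        if any(abs(v) >= threshold for v in values)
    }

# Hypothetical importances: one value per module for each feature.
toy = {
    "feature_1": [0.02, -0.15, 0.01, 0.00],
    "feature_2": [0.03, 0.04, -0.05, 0.01],
    "feature_3": [0.40, 0.00, 0.00, -0.02],
}
kept = filter_features(toy, threshold=0.1)
```

With a threshold of 0.1, `feature_2` is dropped because none of its four values reaches the cutoff, while a strongly negative value is enough to keep `feature_1`.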

Interpretation


The metagenomic modules A-D and features are interpreted for each individual blood serum module E-H:

Discussion


The following are some loosely associated concluding points:

Code


The full code is provided here on GitHub, but the main model code is:

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiOmicsModuleAtt(nn.Module):
    """Two-dataset module attention model with a small prediction head."""

    def __init__(self, feature_dim_a, feature_dim_b, num_modules_a,
                 num_modules_b, hidden_dim, out_dim=1):
        """
        Args:
            feature_dim_a: Da (# features dataset xa)
            feature_dim_b: Db (# features dataset xb)
            num_modules_a: Number of modules desired for dataset xa
            num_modules_b: Number of modules desired for dataset xb
            hidden_dim: Final layer hidden dim
            out_dim: Final output dim (1 for binary classification)
        """
        super().__init__()
        self.num_modules_a = num_modules_a
        self.num_modules_b = num_modules_b
        # One bias-free linear layer per coordinate (x, y) and per dataset
        self.encoder_ax = nn.Linear(feature_dim_a, num_modules_a,
                                    bias=False)
        self.encoder_ay = nn.Linear(feature_dim_a, num_modules_a,
                                    bias=False)
        self.encoder_bx = nn.Linear(feature_dim_b, num_modules_b,
                                    bias=False)
        self.encoder_by = nn.Linear(feature_dim_b, num_modules_b,
                                    bias=False)
        self.fc1 = nn.Linear(num_modules_b * 2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, xa, xb):
        """
        Args:
            xa: Dominant dataset matrix
                (N, Da) for N (# samples), Da (# features)
            xb: Subordinate dataset matrix (N, Db)
        """
        # Encode (x, y) vector coordinates for input datasets xa, xb
        enc_ax = self.encoder_ax(xa)
        enc_ay = self.encoder_ay(xa)
        enc_bx = self.encoder_bx(xb)
        enc_by = self.encoder_by(xb)
        # Stack (x, y) coordinates: (N, Ma, 2) and (N, Mb, 2)
        enc_a = torch.stack([enc_ax, enc_ay], dim=2)
        enc_b = torch.stack([enc_bx, enc_by], dim=2)
        # L2 normalization of each module vector
        enc_a = F.normalize(enc_a, p=2, dim=2)
        enc_b = F.normalize(enc_b, p=2, dim=2)
        # Raw attention matrix (raw cosine similarities): (N, Ma, Mb)
        raw_attn = enc_a @ enc_b.transpose(1, 2)
        # Column softmax: weights over the Ma modules sum to 1 per xb module
        attn_a = F.softmax(raw_attn, dim=1)
        # Attention-encoded modules for xa: (N, Mb, 2)
        att_enc_a = attn_a.transpose(1, 2) @ enc_a
        # Flatten attention-encoded modules, and predict
        att_enc_a_flat = att_enc_a.view(-1, self.num_modules_b * 2)
        out = self.fc1(att_enc_a_flat)
        out = F.relu(out)
        out = self.fc2(out)
        return out.squeeze(1)
