Potts model#

In this tutorial, we will learn a Potts model for a multiple sequence alignment using the pseudo-likelihood method. Then we will plot the coupling terms of the learned model and see how they are related to contacts in the native structure of the protein.

[16]:
import urllib3
import gzip
import pickle
import numpy as np
import os
import jax.numpy as jnp
import jax
import optax
import gemmi
from gemmi import cif
import matplotlib.pyplot as plt
plt.style.use('ggplot')

Process multiple sequence alignment#

The protein family PF00041 will be used as an example. More information about this family can be found at Pfam. First, we need to download the multiple sequence alignment (MSA) of this family from Pfam.

[2]:
pfam_id = "PF00041"

# Download the full Pfam alignment for this family once and cache it
# under ./data/; skip the network round-trip when the file already exists.
if not os.path.exists(f"./data/{pfam_id}_full.txt"):
    os.makedirs("./data", exist_ok=True)
    http = urllib3.PoolManager()
    r = http.request(
        "GET",
        f"https://www.ebi.ac.uk/interpro/wwwapi//entry/pfam/{pfam_id}/?annotation=alignment:full&download",
    )
    # The InterPro endpoint serves the alignment gzip-compressed.
    data = gzip.decompress(r.data)
    data = data.decode()
    # NOTE: the original chained a redundant `.format(pfam_id)` onto an
    # already-interpolated f-string; that call was a no-op and is removed.
    with open(f"./data/{pfam_id}_full.txt", "w") as file_handle:
        print(data, file=file_handle)

Read the MSA file into a dict.

[3]:
## Parse the Stockholm-format alignment into a {name: sequence} dict:
## "#" lines are annotation, "//" terminates the alignment, and every
## remaining non-empty line is a "name sequence" record.
msa = {}
with open(f"./data/{pfam_id}_full.txt", "r") as file_handle:
    for raw_line in file_handle:
        if raw_line.startswith("#"):
            continue
        if raw_line.startswith("//"):
            break

        fields = raw_line.strip().split()
        if not fields:
            continue

        seq_name, seq_str = fields[0], fields[1]
        msa[seq_name] = seq_str.upper()

The original MSA might contain sequences that have too many gaps or positions where too many sequences have gaps. Before using the MSA to learn a Potts model, we will filter it to remove such sequences and positions. Such filtering is important to ensure that the learned model is not biased by sequences or positions that are not informative. We will use the sequence TENA_HUMAN/804-891 as the reference sequence in the filtering process.

[18]:
query_seq_id = "TENA_HUMAN/804-891"

## Drop alignment columns that are gaps in the query sequence, so model
## positions map onto residues of the reference protein.
query_seq = msa[query_seq_id]
is_gap = [c in ("-", ".") for c in query_seq]
for name in msa:
    msa[name] = [c for pos, c in enumerate(msa[name]) if not is_gap[pos]]
query_seq = msa[query_seq_id]

## Discard sequences whose gap count reaches 20% of the query length.
len_query_seq = len(query_seq)
for name in list(msa):
    n_gaps = msa[name].count("-") + msa[name].count(".")
    if n_gaps >= len_query_seq * 0.20:
        msa.pop(name)
[6]:
## Map residues to integer codes: 0 for gaps ("-" or "."), 1..20 for the
## twenty standard amino acids in the order below.
aa = "RHKDESTNQCGPAVILMFYW"

aa_index = {"-": 0, ".": 0}
for code, letter in enumerate(aa, start=1):
    aa_index[letter] = code
with open("./output/aa_index.pkl", "wb") as file_handle:
    pickle.dump(aa_index, file_handle)

## Encode the MSA as an integer matrix, skipping any sequence that
## contains an ambiguous residue code (X, Z, or B).
seq_msa = []
for name in msa.keys():
    if "X" in msa[name] or "Z" in msa[name] or "B" in msa[name]:
        continue
    seq_msa.append([aa_index[c] for c in msa[name]])
seq_msa = np.array(seq_msa, dtype=np.int8)
[7]:
## Keep only alignment columns in which at most 20% of sequences have a
## gap (code 0); save the surviving column indices for later use.
n_seqs = seq_msa.shape[0]
gap_counts = (seq_msa == 0).sum(axis=0)
pos_idx = [i for i in range(seq_msa.shape[1]) if gap_counts[i] <= n_seqs * 0.2]
with open("./output/seq_pos_idx.pkl", 'wb') as file_handle:
    pickle.dump(pos_idx, file_handle)

seq_msa = seq_msa[:, np.array(pos_idx)]
with open("./output/seq_msa.pkl", 'wb') as file_handle:
    pickle.dump(seq_msa, file_handle)

The filtered MSA might also contain many nearly identical sequences. To keep such redundant sequences from dominating the learned model, we assign each sequence a weight based on position-specific amino acid frequencies, so that clusters of similar sequences contribute less per sequence.

[8]:
## Reweight sequences to reduce redundancy. Each residue gets weight
## 1 / (number of residue types in its column x count of its own type),
## and a sequence's weight is its summed residue weight, normalized so
## all sequence weights sum to 1.
seq_weight = np.zeros(seq_msa.shape)
for j in range(seq_msa.shape[1]):
    aa_type, aa_counts = np.unique(seq_msa[:, j], return_counts=True)
    num_type = len(aa_type)
    # Build the type -> count lookup in one pass; the original called
    # list(aa_type).index(a) inside the loop, which was accidentally
    # quadratic in the number of residue types.
    aa_dict = dict(zip(aa_type, aa_counts))
    for i in range(seq_msa.shape[0]):
        seq_weight[i, j] = (1.0 / num_type) * (1.0 / aa_dict[seq_msa[i, j]])
tot_weight = np.sum(seq_weight)
seq_weight = seq_weight.sum(1) / tot_weight
with open("./output/seq_weight.pkl", 'wb') as file_handle:
    pickle.dump(seq_weight, file_handle)

Learn the model#

[9]:
## Model dimensions: L alignment columns, K = 21 states (20 residues + gap).
L = seq_msa.shape[1]
K = 21

## Fields h (K x L) and couplings J (K x K x L x L), initialized to zero.
params = {
    "h": jnp.zeros((K, L)),
    "J": jnp.zeros((K, K, L, L)),
}

## Boolean mask that zeroes the self-coupling slices J[:, :, i, i].
diag = np.arange(L)
mask_J = np.ones((K, K, L, L), dtype=bool)
mask_J[:, :, diag, diag] = False
mask_J = jnp.array(mask_J)


def compute_log_pseudo_likelihood_pos(param, i, seq):
    """Log pseudo-likelihood of residue i of `seq` given all other positions.

    The couplings are masked on the diagonal (no self-coupling) and
    symmetrized before use; the conditional distribution over the K states
    at position i is scored with softmax cross-entropy against the
    observed residue.
    """
    fields = param["h"]
    couplings = jnp.where(mask_J, param["J"], 0.0)
    couplings = 0.5 * (couplings + jnp.transpose(couplings, (0, 1, 3, 2)))

    # Per-state energy at position i: field term plus the couplings between
    # every candidate state at i and the observed residues at all positions.
    pairwise = couplings[:, seq, i, jnp.arange(L)]
    u = fields[:, i] + pairwise.sum(axis=1)

    xent = optax.softmax_cross_entropy_with_integer_labels(
        logits=-u[jnp.newaxis, :], labels=seq[jnp.newaxis, i]
    )
    return -xent


def compute_log_pseudo_likelihood_seq(param, seq):
    """Sum the per-position log pseudo-likelihoods over all L positions."""
    per_position = jax.vmap(
        compute_log_pseudo_likelihood_pos, in_axes=(None, 0, None)
    )(param, jnp.arange(L), seq)
    return jnp.sum(per_position)


def compute_log_pseudo_likelihood_msa(param, seq_msa, seq_weight):
    """Weighted sum of per-sequence log pseudo-likelihoods over the MSA."""
    per_seq = jax.vmap(compute_log_pseudo_likelihood_seq, in_axes=(None, 0))(
        param, seq_msa
    )
    return jnp.sum(per_seq * seq_weight)


## Pad the MSA with dummy sequences so it splits evenly into fixed-size
## batches. The padding has zero weight, so it contributes nothing to the
## loss or its gradient.
batch_size = 1028
# Ceiling division: the original used `n // batch_size + 1`, which created
# a whole extra all-padding batch whenever n was a multiple of batch_size.
num_batches = -(-seq_msa.shape[0] // batch_size)
num_pad = batch_size * num_batches - seq_msa.shape[0]
if num_pad > 0:
    seq_msa = jnp.concatenate(
        (
            seq_msa,
            jnp.zeros((num_pad, L), dtype=jnp.int8),
        )
    )
    # Pad the weights in their own dtype (the original padded a float
    # array with int8 zeros, relying on silent dtype promotion).
    seq_weight = jnp.concatenate(
        (
            seq_weight,
            jnp.zeros((num_pad,), dtype=seq_weight.dtype),
        )
    )

seq_msa_batches = jnp.reshape(seq_msa, (num_batches, batch_size, L))
seq_weight_batches = jnp.reshape(seq_weight, (num_batches, batch_size))

batches = {
    "msa": seq_msa_batches,
    "weight": seq_weight_batches,
}

# L2 regularization strength applied to the coupling tensor J.
weight_decay = 0.05


def compute_loss(params):
    """Negative weighted log pseudo-likelihood over all batches, plus an
    L2 penalty on the couplings.

    Iterates over the pre-built `batches` dict with `jax.lax.scan`;
    `jax.checkpoint` makes each batch's activations be recomputed during
    backpropagation instead of stored, trading compute for memory.
    """
    # `y` collects one log pseudo-likelihood per batch; the carry is unused.
    _, y = jax.lax.scan(
        jax.checkpoint(
            lambda carry, x: (
                carry,
                compute_log_pseudo_likelihood_msa(params, x["msa"], x["weight"]),
            )
        ),
        None,
        batches,
    )
    loss = -jnp.sum(y)
    # NOTE(review): the penalty is on the raw J, i.e. it also includes the
    # diagonal entries that are masked out inside the likelihood.
    loss += weight_decay * jnp.sum(params["J"] ** 2)

    return loss


## Optimize the Potts model parameters with L-BFGS for 20 iterations and
## save the result. (The original computed `v, g = jax.value_and_grad(...)`
## before the loop and never used either value — a dead full
## forward/backward pass — so that line has been removed.)
solver = optax.lbfgs()
opt_state = solver.init(params)
value_and_grad = optax.value_and_grad_from_state(compute_loss)
for _ in range(20):
    value, grad = value_and_grad(params, state=opt_state)
    updates, opt_state = solver.update(
        grad, opt_state, params, value=value, grad=grad, value_fn=compute_loss
    )
    params = optax.apply_updates(params, updates)
    # Report the loss after applying this iteration's update.
    print("Objective function: {:.2E}".format(compute_loss(params)))

with open("./output/params.pkl", "wb") as file_handle:
    pickle.dump(params, file_handle)

Objective function: 2.09E+02
Objective function: 1.91E+02
Objective function: 1.71E+02
Objective function: 1.68E+02
Objective function: 1.66E+02
Objective function: 1.65E+02
Objective function: 1.63E+02
Objective function: 1.63E+02
Objective function: 1.62E+02
Objective function: 1.62E+02
Objective function: 1.61E+02
Objective function: 1.61E+02
Objective function: 1.60E+02
Objective function: 1.60E+02
Objective function: 1.59E+02
Objective function: 1.58E+02
Objective function: 1.57E+02
Objective function: 1.57E+02
Objective function: 1.56E+02
Objective function: 1.56E+02
[10]:
## Load the crystal structure 1TEN and collect its non-water residues.
## (The original parsed the same CIF file twice — once here and once more
## via gemmi.cif.read for make_structure_from_block — and left an unused
## `seq` list of atom-site tags; the file is now read once.)
doc = cif.read("./data/1TEN.cif")
block = doc.sole_block()
structure = gemmi.make_structure_from_block(block)

model = structure[0]
chain = model[0]
residues = [res for res in chain if not res.is_water()]


def compute_min_distance(res1, res2):
    """Return the smallest heavy-atom pair distance between two residues.

    Hydrogens are ignored on both sides; if either residue contributes no
    heavy atom, the result is infinity.
    """
    heavy1 = [atom for atom in res1 if not atom.is_hydrogen()]
    heavy2 = [atom for atom in res2 if not atom.is_hydrogen()]
    return min(
        (a1.pos.dist(a2.pos) for a1 in heavy1 for a2 in heavy2),
        default=float("inf"),
    )


## Pairwise minimum heavy-atom distances between residues. The matrix is
## symmetric, so each unordered pair is computed once and mirrored (the
## original evaluated every ordered pair, doubling the work); the
## diagonal stays 0.
n_res = len(residues)
distances = np.zeros((n_res, n_res))
for i in range(n_res):
    for j in range(i + 1, n_res):
        d = compute_min_distance(residues[i], residues[j])
        distances[i, j] = d
        distances[j, i] = d

with open("./output/distances_from_structure.pkl", "wb") as f:
    pickle.dump(distances, f)
[17]:
## Turn the learned couplings into residue-residue interaction scores:
## double-center each K x K coupling slice, take its Frobenius norm, then
## apply the average-product correction (APC).
with open("./output/params.pkl", "rb") as f:
    params = pickle.load(f)


J = np.array(params["J"])
L = J.shape[-1]

Jp = {}
score_FN = np.zeros([L, L])
for i in range(L):
    for j in range(i + 1, L):
        slab = J[:, :, i, j]
        # Subtract the row and column means, add back the grand mean.
        centered = (
            slab
            - slab.mean(0, keepdims=True)
            - slab.mean(1, keepdims=True)
            + slab.mean()
        )
        Jp[(i, j)] = centered
        frob = np.sqrt(np.sum(centered * centered))
        score_FN[i, j] = frob
        score_FN[j, i] = frob
## APC: subtract the outer product of the row/column means scaled by the
## overall mean score.
score_CN = score_FN - np.outer(score_FN.mean(1), score_FN.mean(0)) / np.mean(score_FN)


## Exclude short-range pairs (|i - j| <= 4) from contact prediction.
for i in range(score_CN.shape[0]):
    for j in range(score_CN.shape[1]):
        if abs(i - j) <= 4:
            score_CN[i, j] = -np.inf

## Threshold at the 160th-largest score, then keep only the strict upper
## triangle so each predicted contact appears exactly once.
flat_scores = np.copy(score_CN).reshape([-1])
flat_scores.sort()
cutoff = flat_scores[-80 * 2]
contact_plm = score_CN > cutoff
for col in range(contact_plm.shape[0]):
    for row in range(col, contact_plm.shape[1]):
        contact_plm[row, col] = False

## Native contacts from the crystal structure: residue pairs whose minimum
## heavy-atom distance is below the cutoff.
with open('./output/distances_from_structure.pkl', 'rb') as f:
    distances = pickle.load(f)

# Distance cutoff for calling a contact; presumably in angstroms, matching
# gemmi's coordinate units — TODO confirm.
cutoff = 6
contact_pdb = distances < cutoff
for i in range(contact_pdb.shape[0]):
    for j in range(contact_pdb.shape[1]):
        # Drop short-range pairs, mirroring the filter applied to the
        # predicted contact map.
        if abs(i-j) <= 4:
            contact_pdb[i,j] = False
        # Keep only the strict upper triangle so each contact appears once.
        if j <= i:
            contact_pdb[i,j] = False

with open("./output/seq_pos_idx.pkl", "rb") as f:
    seq_pos_idx = np.array(pickle.load(f))

# Restrict the native map to the MSA columns kept during filtering.
# NOTE(review): `offset` appears to align MSA column indices with the
# structure's residue ordering (query covers TENA_HUMAN residues 804-891);
# the value 2 is not derived anywhere visible — confirm against the 1TEN
# sequence numbering.
offset = 2
contact_pdb = contact_pdb[seq_pos_idx+offset, :][:, seq_pos_idx+offset]


## Overlay native contacts (blue circles) and predicted contacts (red
## triangles) on one contact map. The np.where outputs are named row/col
## rather than I/J so the coupling tensor J loaded above is not shadowed;
## commented-out dead plotting code has been removed.
fig = plt.figure(figsize = (10,10))
fig.clf()
row, col = np.where(contact_pdb)
plt.plot(row, col, 'bo', alpha = 0.2, markersize = 8, label = 'native contacts from PDB')
row, col = np.where(contact_plm)
plt.plot(row, col, 'r^', markersize = 6, mew = 1.5, label = 'predicted contacts from Potts model')
plt.gca().set_aspect('equal', adjustable='box')

plt.savefig("./output/contact.png")
../../_images/tutorial_potts-model_main_15_0.png