# Files
# extrapolation/model_vectors_publication_images_v3.py
# John Beale 15ea8f8cd5 script dump
# 2026-02-17 08:52:57 +01:00
#
# 426 lines
# 11 KiB
# Python

import gemmi
import numpy as np
import matplotlib.pyplot as plt
# -----------------------------
# Helper: compute SFs
# -----------------------------
def compute_sf(pdb_file, dmin):
    """
    Calculate structure factors from a coordinate file with gemmi.

    The structure is read, its entities and label sequence ids are set up,
    alternative conformations are removed, and one structure factor is
    calculated from the first model for each reflection returned by
    ``gemmi.make_miller_array(..., unique=True)`` down to ``dmin``.

    Parameters
    ----------
    pdb_file : str
        Path handed to ``gemmi.read_structure``.
    dmin : float
        Resolution limit handed to ``gemmi.make_miller_array``.

    Returns
    -------
    F : ndarray (complex128)
        Calculated structure factors, one per reflection.
    hkl : array
        Miller indices as returned by ``gemmi.make_miller_array``.
    cell : gemmi unit cell
        Unit cell of the structure.

    Raises
    ------
    RuntimeError
        If the structure has no usable space group.
    """
    structure = gemmi.read_structure(pdb_file)
    structure.setup_entities()
    structure.assign_label_seq_id()
    structure.remove_alternative_conformations()

    first_model = structure[0]
    unit_cell = structure.cell

    # Prefer gemmi's own lookup; fall back to the Hermann-Mauguin string.
    spacegroup = structure.find_spacegroup()
    if spacegroup is None:
        if not structure.spacegroup_hm:
            raise RuntimeError("No space group found")
        spacegroup = gemmi.SpaceGroup(structure.spacegroup_hm)

    miller = gemmi.make_miller_array(
        cell=unit_cell, spacegroup=spacegroup, dmin=dmin, unique=True
    )

    calculator = gemmi.StructureFactorCalculatorX(unit_cell)
    values = []
    for h, k, l in miller:
        index = [int(h), int(k), int(l)]
        values.append(calculator.calculate_sf_from_model(first_model, index))

    return np.array(values, dtype=np.complex128), miller, unit_cell
def add_resolution_dependent_noise(SF, hkl, cell, alpha=0.01, B=20.0, rng=None):
    """
    Add resolution-dependent fractional noise to complex structure factors.

    Each amplitude is multiplied by ``1 + n`` where ``n`` is drawn from a
    normal distribution with zero mean and standard deviation
    ``alpha * exp(B / (4 * d**2))``.  Phases are preserved: a noisy
    amplitude that would become negative (which would flip the phase by
    180 degrees) is clipped to zero instead.

    Parameters
    ----------
    SF : array-like (complex)
        Structure factors.
    hkl : iterable of Miller indices
        One index triple per entry of ``SF``.
    cell : object with ``calculate_d(hkl)``
        Used to obtain the d-spacing of every reflection.
    alpha : float
        Fractional noise level before the resolution-dependent factor.
    B : float
        Controls how quickly the noise grows as d decreases.
    rng : numpy.random.Generator, optional
        Random generator; a fresh ``default_rng()`` is used when omitted.

    Returns
    -------
    SF_noisy : ndarray (complex)
        Structure factors with noisy amplitudes and the original phases.
    d : ndarray
        d-spacing of every reflection as returned by ``cell.calculate_d``.
    """
    if rng is None:
        rng = np.random.default_rng()

    SF = np.asarray(SF)
    amplitude = np.abs(SF)
    phase = np.angle(SF)
    d = np.array([cell.calculate_d(list(h)) for h in hkl])

    sigma = alpha * np.exp(B / (4.0 * d**2))
    noise = rng.normal(0.0, sigma)

    # Clip at zero: a negative amplitude would silently invert the phase.
    noisy_amplitude = np.maximum(amplitude * (1.0 + noise), 0.0)
    return noisy_amplitude * np.exp(1j * phase), d
def scale_sf_against_ref_mtz(ref_mtz_file, F_gen, hkl_gen, n_bins=20,
                             f_col="F-obs-filtered", phi_col="PHIF-model"):
    """
    Scale generated complex structure factors against a reference MTZ
    using global RMS scaling followed by resolution-binned RMS refinement.

    Only reflections present in both data sets, and whose reference
    amplitude is greater than zero, are used and returned.

    Parameters
    ----------
    ref_mtz_file : str
        Reference MTZ filename
    F_gen : ndarray (complex)
        Generated structure factors
    hkl_gen : ndarray (N,3)
        HKLs for generated SFs
    n_bins : int
        Number of (near-)equal-population resolution bins
    f_col : str
        Label of the amplitude column in the reference MTZ
    phi_col : str
        Label of the phase column (degrees) in the reference MTZ

    Returns
    -------
    Fgen_scaled : ndarray (complex)
        Scaled generated SFs (matched reflections only), ordered from
        large to small d-spacing
    hkls_common : list of tuple
        HKLs used, in the same order as ``Fgen_scaled``

    Raises
    ------
    RuntimeError
        If a requested column is missing, if there are no common
        reflections, or if the matched generated SFs are all zero.
    """
    # -----------------------------
    # Read reference MTZ
    # -----------------------------
    mtz = gemmi.read_mtz_file(ref_mtz_file)
    f_column = mtz.column_with_label(f_col)
    phi_column = mtz.column_with_label(phi_col)
    for label, column in ((f_col, f_column), (phi_col, phi_column)):
        if column is None:
            raise RuntimeError(
                "Column '{0}' not found in {1}".format(label, ref_mtz_file)
            )
    F_ref = f_column.array
    PHI_ref = phi_column.array
    hkl_ref = mtz.make_miller_array()
    cell = mtz.cell

    # Build complex reference SFs
    F_ref_cplx = F_ref * np.exp(1j * np.deg2rad(PHI_ref))

    # -----------------------------
    # HKL -> SF dictionaries
    # -----------------------------
    # The "> 0.0" test also drops NaN amplitudes (NaN compares False).
    ref_dict = {
        tuple(map(int, h)): F_ref_cplx[i]
        for i, h in enumerate(hkl_ref)
        if F_ref[i] > 0.0
    }
    gen_dict = {
        tuple(map(int, h)): F_gen[i]
        for i, h in enumerate(hkl_gen)
    }

    # -----------------------------
    # Find common reflections
    # -----------------------------
    hkls_common = sorted(ref_dict.keys() & gen_dict.keys())
    if not hkls_common:
        raise RuntimeError("No common HKLs between reference MTZ and generated SFs")
    Fref = np.array([ref_dict[h] for h in hkls_common])
    Fgen = np.array([gen_dict[h] for h in hkls_common])

    # -----------------------------
    # Resolution, sorted low -> high (large d first)
    # -----------------------------
    d = np.array([cell.calculate_d(list(h)) for h in hkls_common])
    order = np.argsort(d)[::-1]
    Fref = Fref[order]
    Fgen = Fgen[order]
    d = d[order]
    hkls_common = [hkls_common[i] for i in order]

    # -----------------------------
    # GLOBAL RMS SCALING
    # -----------------------------
    gen_power = np.mean(np.abs(Fgen) ** 2)
    if gen_power <= 0:
        raise RuntimeError("Generated structure factors are all zero; cannot scale")
    k0 = np.sqrt(np.mean(np.abs(Fref) ** 2) / gen_power)
    Fgen = Fgen * k0

    # -----------------------------
    # BIN-WISE RMS REFINEMENT
    # -----------------------------
    Fgen_scaled = Fgen.copy()
    for b in np.array_split(np.arange(len(d)), n_bins):
        if len(b) == 0:
            continue
        den = np.sum(np.abs(Fgen_scaled[b]) ** 2)
        if den <= 0:
            # Nothing to scale in a bin whose generated amplitudes are all zero.
            continue
        k = np.sqrt(np.sum(np.abs(Fref[b]) ** 2) / den)
        Fgen_scaled[b] *= k
    return Fgen_scaled, hkls_common
def resultant(SF):
    """
    Return the magnitude and phase (in degrees) of the vector sum of ``SF``.
    """
    total = np.sum(SF)
    return np.abs(total), np.angle(total, deg=True)
def riso_x8(F1, F2):
    """
    Return sum(| |F1| - |F2| |) / sum(0.5 * (|F1| + |F2|)).

    Only entries whose mean amplitude is greater than zero contribute;
    ``nan`` is returned when no entry qualifies.
    """
    amp1, amp2 = np.abs(F1), np.abs(F2)
    mean_amp = 0.5 * (amp1 + amp2)
    usable = mean_amp > 0
    if not np.any(usable):
        return np.nan
    total_diff = np.sum(np.abs(amp1[usable] - amp2[usable]))
    total_mean = np.sum(mean_amp[usable])
    return total_diff / total_mean
def cciso_x8(F_ref, F_test):
    """
    Xtrapol8 definition of CCiso:
    Pearson correlation coefficient of the amplitudes |F_ref| and |F_test|.

    Pairs with a non-finite amplitude are ignored.  ``nan`` is returned when
    fewer than two pairs remain or when either amplitude set has no variance.
    """
    amp_ref = np.abs(F_ref)
    amp_test = np.abs(F_test)

    # keep only pairs where both amplitudes are finite
    finite = np.isfinite(amp_ref) & np.isfinite(amp_test)
    amp_ref, amp_test = amp_ref[finite], amp_test[finite]
    if len(amp_ref) < 2:
        return np.nan

    dev_ref = amp_ref - np.mean(amp_ref)
    dev_test = amp_test - np.mean(amp_test)
    den = np.sqrt(np.sum(dev_ref**2) * np.sum(dev_test**2))
    if den == 0:
        return np.nan
    return np.sum(dev_ref * dev_test) / den
def make_equal_population_resolution_bins(d, nbins=20):
    """
    Split reflections into ``nbins`` resolution bins of near-equal population.

    Reflections are ordered from large to small d-spacing and divided with
    ``np.array_split``, so bin sizes differ by at most one.

    Parameters
    ----------
    d : array-like (1-D)
        d-spacing of every reflection.
    nbins : int
        Number of bins to create; must be at least 1.

    Returns
    -------
    bins : list of dict
        Always ``nbins`` entries, ordered from large to small d.  Each dict
        holds ``"dmax"`` (largest d in the bin), ``"dmin"`` (smallest d in
        the bin) and ``"indices"`` (positions into ``d``).  A bin that
        receives no reflections (fewer reflections than bins) has empty
        ``"indices"`` and NaN edges.

    Raises
    ------
    ValueError
        If ``nbins`` is smaller than 1.
    """
    if nbins < 1:
        raise ValueError("nbins must be a positive integer")

    d = np.asarray(d)
    # Sort reflections by resolution (low -> high res): large d first
    order = np.argsort(d)[::-1]

    bins = []
    for idx in np.array_split(order, nbins):
        if len(idx) > 0:
            dmax, dmin = d[idx[0]], d[idx[-1]]
        else:
            dmax = dmin = np.nan
        bins.append({
            "dmax": dmax,       # low resolution edge
            "dmin": dmin,       # high resolution edge
            "indices": idx,
        })
    return bins
def match_hkls(F1, hkl1, F2, hkl2, F3, hkl3):
    """
    Match three SF arrays on their common HKLs.

    Parameters
    ----------
    F1, F2, F3 : sequence (complex)
        Structure factors of the three data sets.
    hkl1, hkl2, hkl3 : sequence of Miller indices
        HKLs belonging to ``F1``, ``F2`` and ``F3``; each must have the
        same length as its SF array.

    Returns
    -------
    F1c, F2c, F3c : np.ndarray
        Structure factors restricted to the common HKLs, all in the order
        of ``hkl_common``.
    hkl_common : np.ndarray (N, 3)
        Common HKLs, sorted lexicographically.

    Raises
    ------
    ValueError
        If an SF array and its HKL list differ in length.
    RuntimeError
        If the three data sets share no HKL.
    """
    lookups = []
    for F, hkl in ((F1, hkl1), (F2, hkl2), (F3, hkl3)):
        if len(F) != len(hkl):
            raise ValueError(
                "SF array ({0}) and HKL list ({1}) differ in length".format(
                    len(F), len(hkl)
                )
            )
        # Plain-int tuples so keys from different integer dtypes compare equal.
        lookups.append({tuple(map(int, h)): F[i] for i, h in enumerate(hkl)})

    common = sorted(lookups[0].keys() & lookups[1].keys() & lookups[2].keys())
    if not common:
        raise RuntimeError("No common HKLs between datasets")

    F1c, F2c, F3c = (np.array([lut[h] for h in common]) for lut in lookups)
    hklc = np.array(common, dtype=int)
    return F1c, F2c, F3c, hklc
def riso_cciso_vs_resolution(Fref, Ftest, d, nbins, min_per_bin=20):
    """
    Compute Riso and CCiso per resolution bin.

    Reflections are binned with ``make_equal_population_resolution_bins``;
    bins holding fewer than ``min_per_bin`` reflections are skipped.

    Parameters
    ----------
    Fref, Ftest : array-like
        Structure factors of the two data sets, in the same order as ``d``.
    d : array-like
        d-spacing of every reflection.
    nbins : int
        Number of resolution bins.
    min_per_bin : int
        Minimum number of reflections a bin needs to be reported.

    Returns
    -------
    dmid : ndarray
        Midpoint between the ``dmin`` and ``dmax`` edges of each reported bin.
    riso_vals : ndarray
        ``riso_x8`` of each reported bin.
    cciso_vals : ndarray
        ``cciso_x8`` of each reported bin.
    """
    Fref = np.asarray(Fref)
    Ftest = np.asarray(Ftest)

    dmid = []
    riso_vals = []
    cciso_vals = []
    for b in make_equal_population_resolution_bins(d, nbins):
        idx = b["indices"]
        if len(idx) < min_per_bin:
            continue
        F1 = Fref[idx]
        F2 = Ftest[idx]
        dmid.append(0.5 * (b["dmin"] + b["dmax"]))
        # The metric helpers in this file are riso_x8 / cciso_x8.
        riso_vals.append(riso_x8(F1, F2))
        cciso_vals.append(cciso_x8(F1, F2))
    return np.array(dmid), np.array(riso_vals), np.array(cciso_vals)
def plot_riso_cciso_vs_resolution(d_centers, riso_vals, cciso_vals):
    """
    Plot Riso (left axis, red) and CCiso (right axis, blue) against 1/d.

    Parameters
    ----------
    d_centers : array-like
        Bin-centre d-spacings; the x axis shows their reciprocal.
    riso_vals, cciso_vals : array-like
        Riso and CCiso value for each bin.

    The figure is displayed with ``plt.show()``; nothing is returned.
    """
    inv_d = 1 / np.asarray(d_centers)

    fig, ax1 = plt.subplots()
    ax1.plot(inv_d, riso_vals, color="red", marker='o')
    # The x data are reciprocal d-spacings, so label the axis accordingly.
    ax1.set_xlabel(r"1/d (Å$^{-1}$)")
    ax1.set_ylabel("Riso", color="red")

    ax2 = ax1.twinx()
    ax2.plot(inv_d, cciso_vals, color="blue", marker='s')
    ax2.set_ylabel("CCiso", color="blue")

    plt.title("Riso and CCiso vs Resolution")
    plt.tight_layout()
    plt.show()
# -----------------------------
# User inputs
# -----------------------------
# Alternative "ser-arg-loop" data set (swap in for the hewl paths below):
#   apo     : "../data/ser-arg-loop-apo.pdb"
#   100 %   : "../data/ser-arg-loop-1.0-switch.pdb"
#   partial : "../data/ser-arg-loop-0.1-switch.pdb"
occ_lst = [0.01, 0.05, 0.1, 0.3, 0.5, 0.7, 0.9]
ref_mtz = "../data/apo_100k_refine_5.mtz"
alpha = 0.0001
dmin = 1.5
b_val = 0.1
a_val = 1.0 - b_val

apo_pdb = "../data/hewl_apo.pdb"
occ100_pdb = "../data/hewl_1.0-I.pdb"

# Structure factors of the apo model
print("calculating apo")
SF_apo, hkl_apo, cell_apo = compute_sf(apo_pdb, dmin)
print("done")

# Structure factors of the fully occupied model
print("calculating 100 %")
SF_100, hkl_100, cell_100 = compute_sf(occ100_pdb, dmin)
print("done")
# For every partial occupancy: compute the partially occupied SFs, build
# extrapolated SFs from them and the apo SFs, and compare with the 100 % model.
# NOTE(review): the statements below are all placed inside the loop because
# each of them uses per-occupancy values — confirm against the original layout.
for b_val in occ_lst:
    print("occ = {0}".format(b_val))
    a_val = 1 - b_val

    print("calculating b SFs")
    occb_pdb = "../data/hewl_{0}-I.pdb".format(b_val)
    #occb_pdb = "../data/ser-arg-loop-{0}-switch.pdb".format( b_val )
    SF_b, hkl_b, cell_b = compute_sf(occb_pdb, dmin)
    print("done")

    # Keep the matched arrays under their own names.  Overwriting SF_apo /
    # SF_100 here would leave them out of step with hkl_apo / hkl_100 on the
    # next iteration (match_hkls returns re-sorted, possibly shorter arrays).
    SF_apo_c, SF_100_c, SF_b_c, hklc = match_hkls(
        SF_apo, hkl_apo, SF_100, hkl_100, SF_b, hkl_b
    )
    d = np.array([cell_apo.calculate_d(list(h)) for h in hklc])

    F_apo = np.abs(SF_apo_c)
    phi_apo = np.angle(SF_apo_c)
    F_100 = np.abs(SF_100_c)
    phi_100 = np.angle(SF_100_c)
    F_b = np.abs(SF_b_c)
    phi_b = np.angle(SF_b_c)

    # Amplitude extrapolation to full occupancy, carried on the apo phases.
    F_extr = (F_b - F_apo) / b_val + F_apo
    SF_extr = F_extr * np.exp(1j * phi_apo)
    # d_centers, riso_vals, cciso_vals = riso_cciso_vs_resolution(SF_100_c, SF_extr, d, 20)
    # plot_riso_cciso_vs_resolution(d_centers, riso_vals, cciso_vals)

    delta_Fextr = SF_extr - SF_100_c
    riso = riso_x8(SF_apo_c, SF_b_c)
    print("Riso = {0}".format(riso))
    cciso = cciso_x8(SF_apo_c, SF_b_c)
    print("CCiso = {0}".format(cciso))
    print("Mean |ΔF|:", np.mean(np.abs(delta_Fextr)))
    print("Mean phase Δ (deg):",
          np.mean(np.abs(np.angle(SF_extr / SF_100_c, deg=True))))

    # -----------------------------
    # Resultant vectors for plotting
    # -----------------------------
    sum_SF_apo = np.sum(SF_apo_c)
    sum_SF_100 = np.sum(SF_100_c)
    sum_SF_b = np.sum(SF_b_c)
    sum_SF_extr = np.sum(SF_extr)

    # -----------------------------
    # Plot resultant vectors
    # -----------------------------
    plt.figure(figsize=(8, 8))
    ax = plt.gca()
    ax.set_aspect('equal')
    origin = 0 + 0j
    for vec, colour, label in (
        (sum_SF_apo, 'blue', 'Fapo model'),
        (sum_SF_100, 'red', 'F100 model'),
        (sum_SF_b, 'green', 'Fab model'),
        (sum_SF_extr, 'black', 'Fextr'),
    ):
        ax.arrow(origin.real, origin.imag, vec.real, vec.imag,
                 head_width=500, head_length=500,
                 fc=colour, ec=colour, label=label)
    ax.set_xlabel("Re(F)")
    ax.set_ylabel("Im(F)")
    ax.set_title("SF Resultant Vectors")
    ax.grid(True)
    ax.legend()
    #plt.show()