Files
extrapolation/compare_mtz.py
John Beale 15ea8f8cd5 script dump
2026-02-17 08:52:57 +01:00

207 lines
4.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
import gemmi
import matplotlib.pyplot as plt
def riso_xtrapol8(F1, F2):
    """Return the Xtrapol8-style isomorphous R-factor of two datasets.

    Riso = sum(|A1 - A2|) / sum(0.5 * (A1 + A2)), with A1 = |F1| and
    A2 = |F2|. Only reflections whose mean magnitude is strictly positive
    take part. A NaN mean fails the ``> 0`` test, so such reflections drop out.

    Parameters
    ----------
    F1, F2 : array-like
        Values for the same reflections in the same order; only their
        magnitudes (``np.abs``) are used.

    Returns
    -------
    float
        Riso, or NaN when no reflection has a positive mean magnitude.
    """
    amp1 = np.abs(F1)
    amp2 = np.abs(F2)
    mean_amp = 0.5 * (amp1 + amp2)
    usable = mean_amp > 0
    if not usable.any():
        return np.nan
    abs_diff_total = np.abs(amp1[usable] - amp2[usable]).sum()
    mean_total = mean_amp[usable].sum()
    return abs_diff_total / mean_total
def cciso_xtrapol8(F1, F2):
    """Return the Pearson correlation of |F1| and |F2| (Xtrapol8 CCiso).

    Only reflections where both magnitudes are finite are used.

    Parameters
    ----------
    F1, F2 : array-like
        Values for the same reflections in the same order; only their
        magnitudes (``np.abs``) are used.

    Returns
    -------
    float
        The correlation coefficient. Returns NaN when fewer than two finite
        pairs remain or when the normalising term is not positive (zero
        variance in either dataset).
    """
    a = np.abs(F1)
    b = np.abs(F2)
    both_finite = np.isfinite(a) & np.isfinite(b)
    a = a[both_finite]
    b = b[both_finite]
    if len(a) < 2:
        return np.nan
    da = a - np.mean(a)
    db = b - np.mean(b)
    norm = np.sqrt(np.sum(da**2) * np.sum(db**2))
    if not norm > 0:
        # Zero variance: the correlation is undefined.
        return np.nan
    return np.sum(da * db) / norm
def read_mtz_amplitudes(mtz_file, amp_label):
    """Read Miller indices, one data column and the unit cell from an MTZ file.

    Parameters
    ----------
    mtz_file : str
        Path to the MTZ file (read with ``gemmi.read_mtz_file``).
    amp_label : str
        Label of the column to return, e.g. ``"F-obs-filtered"``.

    Returns
    -------
    hkl : ndarray of int, shape (n, 3)
        Miller indices (columns H, K, L), one row per reflection.
    F : ndarray of float, shape (n,)
        Values of the ``amp_label`` column. Any NaN values are passed
        through unchanged.
    cell : gemmi.UnitCell
        The file's unit cell (``mtz.cell``).

    Raises
    ------
    KeyError
        If ``amp_label`` (or one of H, K, L) is not a column of the file.
    """
    mtz = gemmi.read_mtz_file(mtz_file)

    def column(label):
        # column_with_label() returns None for an unknown label. Without
        # this check, np.array(None, dtype=float) silently produces a NaN
        # array, and the failure surfaces later as an unrelated error.
        col = mtz.column_with_label(label)
        if col is None:
            available = ", ".join(c.label for c in mtz.columns)
            raise KeyError(
                f"column {label!r} not found in {mtz_file!r} "
                f"(available: {available})"
            )
        return col

    hkl = np.column_stack(
        [np.array(column(name), dtype=int) for name in ("H", "K", "L")]
    )
    F = np.array(column(amp_label), dtype=float)
    return hkl, F, mtz.cell
def match_reflections(hkl1, F1, hkl2, F2):
    """Pair up the reflections that two datasets have in common.

    Each (h, k, l) row is converted to a tuple and used as a dictionary key.
    If an index occurs more than once within one dataset, its last value wins.

    Returns
    -------
    hklm : ndarray of int
        Common Miller indices, sorted lexicographically.
    F1m, F2m : ndarray
        Values taken from F1 / F2 for those indices, in the order of hklm.
    """
    lookup1 = dict(zip(map(tuple, hkl1), F1))
    lookup2 = dict(zip(map(tuple, hkl2), F2))
    shared = sorted(lookup1.keys() & lookup2.keys())
    F1m = np.array([lookup1[key] for key in shared])
    F2m = np.array([lookup2[key] for key in shared])
    return np.array(shared, dtype=int), F1m, F2m
def compute_resolution(cell, hkl):
    """Return the d-spacing of every reflection in ``hkl``.

    ``cell`` must provide ``calculate_d([h, k, l])`` (in this file it is the
    ``gemmi.UnitCell`` read from the MTZ). Each Miller index is passed as a
    plain list.

    Returns an ndarray with one d value per row of ``hkl``.
    """
    spacings = []
    for miller in hkl:
        spacings.append(cell.calculate_d(list(miller)))
    return np.array(spacings)
def make_equal_count_bins(d, n_bins):
    """
    Split reflection indices into ``n_bins`` groups of near-equal size.

    Indices are ordered from low to high resolution (decreasing d). This is
    an alias of ``make_equal_count_resolution_bins``, kept for existing
    callers. Delegating keeps a single implementation of the binning rule
    instead of two identical copies that could drift apart.

    Returns a list of integer index arrays into ``d``.
    """
    return make_equal_count_resolution_bins(d, n_bins)
def make_equal_count_resolution_bins(d, n_bins):
    """
    Partition reflection indices into ``n_bins`` groups of roughly equal size.

    Indices are sorted by decreasing d (low -> high resolution) before
    splitting, so the first group holds the lowest-resolution reflections.
    Group sizes follow ``numpy.array_split`` semantics: they differ by at
    most one, and trailing groups are empty if ``n_bins`` exceeds ``len(d)``.

    Returns a list of integer index arrays into ``d``.
    """
    by_decreasing_d = np.flip(np.argsort(d))
    index_bins = np.array_split(by_decreasing_d, n_bins)
    return index_bins
def riso_cciso_by_resolution(d, F1, F2, n_bins=20, min_reflections=10):
    """
    Compute Xtrapol8 Riso and CCiso in equal-count resolution bins.

    Parameters
    ----------
    d : array-like
        d-spacing of each reflection.
    F1, F2 : array-like
        Matched values for the same reflections, aligned with ``d``.
    n_bins : int, optional
        Number of equal-count bins, ordered low -> high resolution.
    min_reflections : int, optional
        Bins with fewer reflections than this are skipped (default 10).

    Returns
    -------
    d_mid : ndarray
        Midpoint of the smallest and largest d in each kept bin.
    Riso : ndarray
        Riso per kept bin.
    CCiso : ndarray
        CCiso per kept bin.
    counts : ndarray
        Number of reflections per kept bin.
    """
    # Fancy indexing with the bin index arrays below requires ndarrays, so
    # accept plain lists as well.
    d = np.asarray(d)
    F1 = np.asarray(F1)
    F2 = np.asarray(F2)

    d_mid, riso, cciso, counts = [], [], [], []
    for idx in make_equal_count_resolution_bins(d, n_bins):
        if len(idx) < min_reflections:
            continue
        d_bin = d[idx]
        F1b = F1[idx]
        F2b = F2[idx]
        d_mid.append(0.5 * (d_bin.min() + d_bin.max()))
        riso.append(riso_xtrapol8(F1b, F2b))
        cciso.append(cciso_xtrapol8(F1b, F2b))
        counts.append(len(idx))

    return (
        np.array(d_mid),
        np.array(riso),
        np.array(cciso),
        np.array(counts),
    )
def plot_riso_cciso_vs_resolution(d_mid, Riso, CCiso, show=True):
    """Plot per-bin Riso and CCiso against reciprocal resolution 1/d.

    Riso is drawn on the left y-axis and CCiso (dashed) on a twin right
    y-axis, with a single combined legend.

    Parameters
    ----------
    d_mid : array-like
        Bin-centre d-spacings; the x coordinate plotted is ``1 / d_mid``.
    Riso, CCiso : array-like
        Per-bin statistics aligned with ``d_mid``.
    show : bool, optional
        Call ``plt.show()`` (the default). Pass False to keep the figure for
        saving or further editing.

    Returns
    -------
    matplotlib.figure.Figure
        The created figure.
    """
    d_mid = np.asarray(d_mid, dtype=float)
    CCiso = np.asarray(CCiso, dtype=float)
    inv_d = 1.0 / d_mid

    fig, ax1 = plt.subplots(figsize=(6, 4))

    # Riso on the left axis.
    ax1.plot(inv_d, Riso, marker='o', label='Riso')
    # The x data are 1/d, so the axis is labelled in reciprocal Å.
    ax1.set_xlabel(r'1/d (Å$^{-1}$)')
    ax1.set_ylabel('Riso')
    ax1.grid(True, alpha=0.3)

    # CCiso on a twin right axis.
    ax2 = ax1.twinx()
    ax2.plot(inv_d, CCiso, marker='s', linestyle='--', label='CCiso')
    ax2.set_ylabel('CCiso')
    # A correlation can be negative in noisy bins; extend the lower limit
    # so such points are not clipped out of view.
    finite_cc = CCiso[np.isfinite(CCiso)]
    lower = min(0.0, float(finite_cc.min())) if finite_cc.size else 0.0
    ax2.set_ylim(lower, 1)

    # One legend covering the lines of both axes.
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax1.legend(lines + lines2, labels + labels2, loc='best')

    plt.tight_layout()
    if show:
        plt.show()
    return fig
def riso_cciso_from_mtz(
    mtz1, mtz2,
    amp1_label, amp2_label,
    n_bins=None,
    plot=True,
):
    """Compare two MTZ files with Xtrapol8-style Riso / CCiso statistics.

    Reads one column from each file and matches reflections on (h, k, l).
    Prints the overall Riso/CCiso and, optionally, a per-bin table, then
    plots the per-bin statistics against resolution.

    Parameters
    ----------
    mtz1, mtz2 : str
        Paths to the two MTZ files.
    amp1_label, amp2_label : str
        Column labels to read from ``mtz1`` / ``mtz2``.
    n_bins : int or None, optional
        If given, print a table of Riso/CCiso in this many equal-count
        resolution bins and use the same number of bins for the plot.
        If None, no table is printed and the plot uses 20 bins.
    plot : bool, optional
        Draw the per-bin plot (default True).

    Returns
    -------
    tuple of float
        ``(Riso, CCiso)`` over all matched reflections.
    """
    hkl1, F1, cell1 = read_mtz_amplitudes(mtz1, amp1_label)
    hkl2, F2, _cell2 = read_mtz_amplitudes(mtz2, amp2_label)
    hkl, F1m, F2m = match_reflections(hkl1, F1, hkl2, F2)
    # d-spacings are computed from the first file's cell only; the second
    # file's cell is not consulted.
    d = compute_resolution(cell1, hkl)
    print(f"Matched reflections: {len(hkl)}")

    # Overall statistics.
    Riso = riso_xtrapol8(F1m, F2m)
    CCiso = cciso_xtrapol8(F1m, F2m)
    print("\nOverall:")
    print(f" Riso = {Riso:.4f}")
    print(f" CCiso = {CCiso:.4f}")

    # Optional per-bin table.
    if n_bins is not None:
        print("\nPer-resolution-bin statistics:")
        for i, idx in enumerate(make_equal_count_bins(d, n_bins)):
            if len(idx) < 10:
                continue
            Rb = riso_xtrapol8(F1m[idx], F2m[idx])
            CCb = cciso_xtrapol8(F1m[idx], F2m[idx])
            dmin = np.min(d[idx])
            dmax = np.max(d[idx])
            # The explicit " - " keeps wide values like 50.00 and 10.00 from
            # running together.
            print(
                f"Bin {i+1:02d} "
                f"{dmax:5.2f} - {dmin:5.2f} Å "
                f"Riso={Rb:6.3f} CCiso={CCb:6.3f} n={len(idx)}"
            )

    if plot:
        # Use the caller's binning when one was requested, so the plot agrees
        # with the printed table.
        plot_bins = 20 if n_bins is None else n_bins
        d_mid, riso_bins, cciso_bins, _counts = riso_cciso_by_resolution(
            d, F1m, F2m, n_bins=plot_bins
        )
        plot_riso_cciso_vs_resolution(d_mid, riso_bins, cciso_bins)

    return Riso, CCiso
# Input files and column labels for the example comparison. The paths are
# relative to the current working directory.
mtz1 = "../data/apo_100k_refine_5.mtz"
mtz2 = "../data/on_100k_refine_141.mtz"
amp1_label = "F-obs-filtered"
amp2_label = "F-obs-filtered"

if __name__ == "__main__":
    # Guarded so that importing this module to reuse the statistics helpers
    # does not read the data files or open a plot window.
    riso_cciso_from_mtz(mtz1, mtz2, amp1_label, amp2_label, n_bins=None)