Enable site-specific occupancy changes at specific ROIs

This commit is contained in:
2026-03-24 13:50:25 +01:00
parent d16c968482
commit be5bac2778
+173 -7
View File
@@ -5,6 +5,60 @@ import argparse
import tempfile
import os
import sys
import yaml
# -----------------------------
# load YAML
# -----------------------------
def load_yaml(file):
    """Parse a YAML file and return the loaded object.

    The file is opened in binary mode so that PyYAML performs its own
    encoding detection (BOM check, UTF-8 otherwise) instead of depending on
    the platform's locale encoding, which can mis-decode non-ASCII content.

    Args:
        file: Path to the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the document.
    """
    with open(file, "rb") as f:
        return yaml.safe_load(f)
# -----------------------------
# build site atom map
# -----------------------------
def build_site_map(rois):
    """Build a lookup from site name to its list of atom records.

    Args:
        rois: Mapping with a ``"regions"`` entry; each region carries an
            ``"atoms"`` list whose items provide ``"atom"``, ``"res_name"``,
            ``"chain"`` and ``"resid"``.

    Returns:
        Dict mapping every region name to a list of dicts with the keys
        ``"atom"``, ``"resn"``, ``"chain"`` and ``"resid"``.
    """
    return {
        site_name: [
            {
                "atom": entry["atom"],
                # the ROI file calls this "res_name"; it is stored as "resn"
                "resn": entry["res_name"],
                "chain": entry["chain"],
                "resid": entry["resid"],
            }
            for entry in region["atoms"]
        ]
        for site_name, region in rois["regions"].items()
    }
def is_mainchain(atom_name):
    """Return True when ``atom_name`` is one of N, CA, C or O."""
    backbone_names = frozenset(("N", "CA", "C", "O"))
    return atom_name in backbone_names
def get_site_for_atom(atom_name, chain, resid, resn, site_map):
    """Find the first site containing the residue the atom belongs to.

    Matching is done on the whole residue (chain, residue number and residue
    name); ``atom_name`` is accepted but does not take part in the
    comparison.

    Args:
        atom_name: Name of the atom (not used for matching).
        chain: Chain identifier of the atom's residue.
        resid: Residue number of the atom's residue.
        resn: Residue name of the atom's residue.
        site_map: Dict of site name -> list of atom records with
            ``"chain"``, ``"resid"`` and ``"resn"`` keys.

    Returns:
        The name of the first matching site in ``site_map`` iteration order,
        or ``None`` when no site lists that residue.
    """
    for site_name, members in site_map.items():
        # FULL RESIDUE MATCH: any listed atom of the residue selects the site
        residue_listed = any(
            member["chain"] == chain
            and member["resid"] == resid
            and member["resn"] == resn
            for member in members
        )
        if residue_listed:
            return site_name
    return None
# -----------------------------
@@ -133,6 +187,27 @@ def report_state_differences(dark, light):
print()
# -----------------------------
# collect atom keys so atoms (e.g. waters) can be checked for presence in both light and dark
# -----------------------------
def get_atom_keys(struct):
    """Collect an identifying key for every atom in the first model.

    Args:
        struct: Structure whose first element (``struct[0]``) iterates over
            chains, each chain over residues, and each residue over atoms.

    Returns:
        Set of ``(chain name, residue number, residue name, atom name)``
        tuples, one per atom found in ``struct[0]``.
    """
    return {
        (chain.name, res.seqid.num, res.name, atom.name)
        for chain in struct[0]
        for res in chain
        for atom in res
    }
# -----------------------------
# occupancy sanity check
# -----------------------------
@@ -216,9 +291,11 @@ def check_missing_light_residues(dark, light, ensemble):
# -----------------------------
# Build ensemble model
# -----------------------------
def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False):
def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False,
site_occ=None, site_map=None):
dark_occ = 1.0 - light_occ
dark_keys = get_atom_keys(dark)
light_keys = get_atom_keys(light)
ensemble = dark.clone()
model = ensemble[0]
@@ -232,7 +309,28 @@ def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False):
for res in chain:
for atom in res:
atom.occ *= dark_occ
# default dark occupancy
occ = 1.0 - light_occ
# override if site-specific
if site_map and site_occ:
site = get_site_for_atom(
atom.name,
chain.name,
res.seqid.num,
res.name,
site_map
)
if site:
occ = 1.0 - site_occ.get(site, light_occ)
key = (chain.name, res.seqid.num, res.name, atom.name)
if key not in light_keys:
# no light counterpart → keep full occupancy
atom.occ *= 1.0
else:
atom.occ *= occ
if atom.occ < occ_cutoff:
continue
@@ -242,6 +340,17 @@ def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False):
dark_atoms += 1
if site:
print(
"SITE MATCH:",
site,
chain.name,
res.name,
res.seqid.num,
atom.name,
"→ occ =", occ
)
light_atoms = 0
# -----------------
@@ -274,15 +383,34 @@ def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False):
new_res.seqid = gemmi.SeqId(res.seqid.num, res.seqid.icode)
target_chain.add_residue(new_res)
# IMPORTANT: retrieve the actual residue stored in the chain
target_res = target_chain[-1]
res_lookup[key] = target_res
for atom in res:
new_occ = atom.occ * light_occ
# default light occupancy
occ = light_occ
# override if site-specific
if site_map and site_occ:
site = get_site_for_atom(
atom.name,
chain.name,
res.seqid.num,
res.name,
site_map
)
if site:
occ = site_occ.get(site, light_occ)
key = (chain.name, res.seqid.num, res.name, atom.name)
if key not in dark_keys:
# no dark counterpart → full occupancy
new_occ = atom.occ * 1.0
else:
new_occ = atom.occ * occ
if new_occ < occ_cutoff:
@@ -308,6 +436,18 @@ def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False):
target_res.add_atom(new_atom)
if site:
print(
"SITE MATCH:",
site,
chain.name,
res.name,
res.seqid.num,
atom.name,
"→ occ =", occ
)
light_atoms += 1
print()
@@ -369,6 +509,16 @@ def main():
help="Atoms with scaled occupancy below this value are removed"
)
parser.add_argument(
"--site-occ",
help="YAML file with site occupancies"
)
parser.add_argument(
"--roi",
help="ROI YAML file"
)
parser.add_argument(
"-v",
"--verbose",
@@ -403,12 +553,28 @@ def main():
# report any residue differences
report_state_differences(dark, light)
site_occ = None
site_map = None
if args.site_occ and args.roi:
site_occ_data = load_yaml(args.site_occ)
roi_data = load_yaml(args.roi)
site_occ = site_occ_data["site_occupancies"]
site_map = build_site_map(roi_data)
print("Using site-specific occupancies:")
print(site_occ)
ensemble = build_ensemble(
dark,
light,
args.fraction,
occ_cutoff=args.occ_cutoff,
verbose=args.verbose
verbose=args.verbose,
site_occ=site_occ,
site_map=site_map
)
# check all the residues are there