diff --git a/generate_combined_pdb.py b/generate_combined_pdb.py index 9e202ad..de5c7c9 100644 --- a/generate_combined_pdb.py +++ b/generate_combined_pdb.py @@ -5,6 +5,60 @@ import argparse import tempfile import os import sys +import yaml + +# ----------------------------- +# load YAML +# ----------------------------- +def load_yaml(file): + with open(file) as f: + return yaml.safe_load(f) + + +# ----------------------------- +# build site atom map +# ----------------------------- +def build_site_map(rois): + + site_map = {} + + for site, info in rois["regions"].items(): + + atoms = [] + + for a in info["atoms"]: + + atoms.append({ + "atom": a["atom"], + "resn": a["res_name"], + "chain": a["chain"], + "resid": a["resid"] + }) + + site_map[site] = atoms + + return site_map + + +def is_mainchain(atom_name): + return atom_name in {"N", "CA", "C", "O"} + + +def get_site_for_atom(atom_name, chain, resid, resn, site_map): + + for site, atoms in site_map.items(): + + for a in atoms: + + if ( + chain == a["chain"] and + resid == a["resid"] and + resn == a["resn"] + ): + # FULL RESIDUE MATCH + return site + + return None # ----------------------------- @@ -133,6 +187,27 @@ def report_state_differences(dark, light): print() +# ----------------------------- +# lookup and check if waters are in both light and dark +# ----------------------------- +def get_atom_keys(struct): + + keys = set() + + for chain in struct[0]: + for res in chain: + for atom in res: + + keys.add(( + chain.name, + res.seqid.num, + res.name, + atom.name + )) + + return keys + + # ----------------------------- # occupancy sanity check # ----------------------------- @@ -216,9 +291,11 @@ def check_missing_light_residues(dark, light, ensemble): # ----------------------------- # Build ensemble model # ----------------------------- -def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False): +def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False, + site_occ=None, site_map=None): - dark_occ = 1.0 - light_occ + dark_keys = get_atom_keys(dark) + light_keys = get_atom_keys(light) ensemble = dark.clone() model = ensemble[0] @@ -232,7 +309,28 @@ def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False): for res in chain: for atom in res: - atom.occ *= dark_occ + # default dark occupancy + occ = 1.0 - light_occ + + # override if site-specific + if site_map and site_occ: + site = get_site_for_atom( + atom.name, + chain.name, + res.seqid.num, + res.name, + site_map + ) + if site: + occ = 1.0 - site_occ.get(site, light_occ) + + key = (chain.name, res.seqid.num, res.name, atom.name) + + if key not in light_keys: + # no light counterpart → keep full occupancy + atom.occ *= 1.0 + else: + atom.occ *= occ if atom.occ < occ_cutoff: continue @@ -242,6 +340,17 @@ def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False): dark_atoms += 1 + if site: + print( + "SITE MATCH:", + site, + chain.name, + res.name, + res.seqid.num, + atom.name, + "→ occ =", occ + ) + light_atoms = 0 # ----------------- @@ -274,15 +383,34 @@ def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False): new_res.seqid = gemmi.SeqId(res.seqid.num, res.seqid.icode) target_chain.add_residue(new_res) - - # IMPORTANT: retrieve the actual residue stored in the chain target_res = target_chain[-1] res_lookup[key] = target_res for atom in res: - new_occ = atom.occ * light_occ + # default light occupancy + occ = light_occ + + # override if site-specific + if site_map and site_occ: + site = get_site_for_atom( + atom.name, + chain.name, + res.seqid.num, + res.name, + site_map + ) + if site: + occ = site_occ.get(site, light_occ) + + key = (chain.name, res.seqid.num, res.name, atom.name) + + if key not in dark_keys: + # no dark counterpart → full occupancy + new_occ = atom.occ * 1.0 + else: + new_occ = atom.occ * occ if new_occ < occ_cutoff: @@ -308,6 +436,18 @@ def build_ensemble(dark, light, light_occ, occ_cutoff=0.005, verbose=False): target_res.add_atom(new_atom) + + if site: + print( + "SITE MATCH:", + site, + chain.name, + res.name, + res.seqid.num, + atom.name, + "→ occ =", occ + ) + light_atoms += 1 print() @@ -369,6 +509,16 @@ def main(): help="Atoms with scaled occupancy below this value are removed" ) + parser.add_argument( + "--site-occ", + help="YAML file with site occupancies" + ) + + parser.add_argument( + "--roi", + help="ROI YAML file" + ) + parser.add_argument( "-v", "--verbose", @@ -403,12 +553,28 @@ def main(): # report any residue differences report_state_differences(dark, light) + site_occ = None + site_map = None + + if args.site_occ and args.roi: + + site_occ_data = load_yaml(args.site_occ) + roi_data = load_yaml(args.roi) + + site_occ = site_occ_data["site_occupancies"] + site_map = build_site_map(roi_data) + + print("Using site-specific occupancies:") + print(site_occ) + ensemble = build_ensemble( dark, light, args.fraction, occ_cutoff=args.occ_cutoff, - verbose=args.verbose + verbose=args.verbose, + site_occ=site_occ, + site_map=site_map ) # check all the residues are there