|
| 1 | +""" |
| 2 | +scaffoldgraph tests.vis.test_vis_utils |
| 3 | +""" |
| 4 | + |
| 5 | +import scaffoldgraph.vis.utils as vis_utils |
| 6 | +import matplotlib.pyplot as plt |
| 7 | +import random |
| 8 | +import pytest |
| 9 | +import re |
| 10 | + |
| 11 | +from rdkit.Chem.Draw import rdMolDraw2D |
| 12 | +from rdkit import Chem |
| 13 | + |
| 14 | +from scaffoldgraph.utils import suppress_rdlogger |
| 15 | +from . import long_test_network |
| 16 | + |
| 17 | + |
| 18 | +SVG_PATTERN = r'(?:<\?xml\b[^>]*>[^<]*)?(?:<!--.*?-->[^<]*)*(?:<svg|<!DOCTYPE svg)\b' |
| 19 | +SVG_REGEX = re.compile(SVG_PATTERN, re.DOTALL) |
| 20 | + |
| 21 | +SVG_DIM_PATTERN = r"width='(\d+px)'\s+height='(\d+px)" |
| 22 | +SVG_DIM_REGEX = re.compile(SVG_DIM_PATTERN) |
| 23 | + |
| 24 | +HEX_PATTERN = r'^#([A-Fa-f0-9]{6}|[A-Fa-f0-9]{3})$' |
| 25 | +HEX_REGEX = re.compile(HEX_PATTERN) |
| 26 | + |
| 27 | + |
| 28 | +def naive_svg_check(svg_string): |
| 29 | + """Validate SVG format (naive).""" |
| 30 | + return SVG_REGEX.match(svg_string) is not None |
| 31 | + |
| 32 | + |
| 33 | +def svg_dimensions(svg_string): |
| 34 | + """Return dimensions of an (rdkit) SVG string.""" |
| 35 | + matches = SVG_DIM_REGEX.findall(svg_string) |
| 36 | + if not matches: |
| 37 | + return (None, None) |
| 38 | + dims = map(lambda x: int(x.replace('px', '')), matches[0]) |
| 39 | + return tuple(dims) |
| 40 | + |
| 41 | + |
| 42 | +def insert_random_node_attribute(graph, key, high=1, low=0): |
| 43 | + """Add a random attribute to nodes in a graph.""" |
| 44 | + for _, data in graph.nodes(data=True): |
| 45 | + value = random.uniform(low, high) |
| 46 | + data[key] = value |
| 47 | + |
| 48 | + |
| 49 | +def is_valid_hex(hex): |
| 50 | + """Validate hexadecimal color code.""" |
| 51 | + if hex is None: |
| 52 | + return False |
| 53 | + if HEX_REGEX.search(hex): |
| 54 | + return True |
| 55 | + return False |
| 56 | + |
| 57 | + |
| 58 | +def test_smiles_to_svg(): |
| 59 | + smi = 'Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1' |
| 60 | + img = vis_utils.smiles_to_svg(smi) # smiles to SVG |
| 61 | + assert img is not None |
| 62 | + assert naive_svg_check(img) is True |
| 63 | + # Check size updates. |
| 64 | + dims = (450, 400) |
| 65 | + img = vis_utils.smiles_to_svg(smi, size=dims) |
| 66 | + assert svg_dimensions(img) == dims |
| 67 | + # Check drawing options (clear background). |
| 68 | + drawOpts = rdMolDraw2D.MolDrawOptions() |
| 69 | + drawOpts.clearBackground = True |
| 70 | + img = vis_utils.smiles_to_svg(smi, draw_options=drawOpts) |
| 71 | + assert '</rect>' in img # <rect> exists with background |
| 72 | + drawOpts.clearBackground = False |
| 73 | + img = vis_utils.smiles_to_svg(smi, draw_options=drawOpts) |
| 74 | + assert '</rect>' not in img |
| 75 | + |
| 76 | + |
| 77 | +@suppress_rdlogger() |
| 78 | +def test_smiles_to_image(): |
| 79 | + # These aren't paticularly great tests for this function... |
| 80 | + smi = 'Cc1ccc(NC(=O)c2ccc(CN3CCN(C)CC3)cc2)cc1Nc1nccc(-c2cccnc2)n1' |
| 81 | + img = vis_utils.smiles_to_image(smi) # smiles to SVG |
| 82 | + assert img is not None |
| 83 | + assert img != 'data:image/svg+xml;charset=utf-8,' |
| 84 | + null_smi = 'xxx' |
| 85 | + img = vis_utils.smiles_to_image(null_smi) |
| 86 | + assert img == 'data:image/svg+xml;charset=utf-8,' |
| 87 | + |
| 88 | + |
| 89 | +def test_embed_node_mol_images(network): |
| 90 | + # Embed images into node attributes. |
| 91 | + vis_utils.embed_node_mol_images(network) |
| 92 | + for _, data in network.nodes(data=True): |
| 93 | + img = data.get('img', None) |
| 94 | + assert img is not None |
| 95 | + # Remove images from node attributes. |
| 96 | + vis_utils.remove_node_mol_images(network) |
| 97 | + for _, data in network.nodes(data=True): |
| 98 | + img = data.get('img', None) |
| 99 | + assert img is None |
| 100 | + |
| 101 | + |
| 102 | +def test_color_nodes_by_attribute(network): |
| 103 | + key = 'attr' |
| 104 | + insert_random_node_attribute(network, key) |
| 105 | + # Color scaffold nodes. |
| 106 | + vis_utils.color_scaffold_nodes_by_attribute(network, key, 'BuPu') |
| 107 | + for _, data in network.get_scaffold_nodes(data=True): |
| 108 | + c = data.get('color', None) |
| 109 | + assert c is not None |
| 110 | + assert is_valid_hex(c) |
| 111 | + # Color molecule nodes. |
| 112 | + cmap = plt.get_cmap('hot') |
| 113 | + vis_utils.color_molecule_nodes_by_attribute(network, key, cmap, 'col') |
| 114 | + for _, data in network.get_molecule_nodes(data=True): |
| 115 | + c = data.get('col', None) |
| 116 | + assert c is not None |
| 117 | + assert is_valid_hex(c) |
| 118 | + |
| 119 | + |
| 120 | +def test_root_node(network): |
| 121 | + vis_utils.add_root_node(network) |
| 122 | + assert network.has_node('root') is True |
| 123 | + assert network.in_degree('root') == 0 |
| 124 | + vis_utils.remove_root_node(network) |
| 125 | + assert network.has_node('root') is False |
0 commit comments