Hack90 committed on
Commit
d8a3b21
·
verified ·
1 Parent(s): c60debb

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +877 -144
  2. requirements.txt +5 -4
app.py CHANGED
@@ -1,151 +1,884 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from pathlib import Path
2
- from typing import List, Dict, Tuple
3
- import matplotlib.colors as mpl_colors
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import pandas as pd
6
- import seaborn as sns
7
- import shinyswatch
8
-
9
- from shiny import App, Inputs, Outputs, Session, reactive, render, req, ui
10
-
11
- sns.set_theme()
12
-
13
- www_dir = Path(__file__).parent.resolve() / "www"
14
-
15
- df = pd.read_csv(Path(__file__).parent / "penguins.csv", na_values="NA")
16
- numeric_cols: List[str] = df.select_dtypes(include=["float64"]).columns.tolist()
17
- species: List[str] = df["Species"].unique().tolist()
18
- species.sort()
19
-
20
- app_ui = ui.page_fillable(
21
- shinyswatch.theme.minty(),
22
- ui.layout_sidebar(
23
- ui.sidebar(
24
- # Artwork by @allison_horst
25
- ui.input_selectize(
26
- "xvar",
27
- "X variable",
28
- numeric_cols,
29
- selected="Bill Length (mm)",
30
- ),
31
- ui.input_selectize(
32
- "yvar",
33
- "Y variable",
34
- numeric_cols,
35
- selected="Bill Depth (mm)",
36
- ),
37
- ui.input_checkbox_group(
38
- "species", "Filter by species", species, selected=species
39
- ),
40
- ui.hr(),
41
- ui.input_switch("by_species", "Show species", value=True),
42
- ui.input_switch("show_margins", "Show marginal plots", value=True),
43
- ),
44
- ui.output_ui("value_boxes"),
45
- ui.output_plot("scatter", fill=True),
46
- ui.help_text(
47
- "Artwork by ",
48
- ui.a("@allison_horst", href="https://twitter.com/allison_horst"),
49
- class_="text-end",
50
- ),
51
- ),
52
- )
53
-
54
-
55
- def server(input: Inputs, output: Outputs, session: Session):
56
- @reactive.Calc
57
- def filtered_df() -> pd.DataFrame:
58
- """Returns a Pandas data frame that includes only the desired rows"""
59
-
60
- # This calculation "req"uires that at least one species is selected
61
- req(len(input.species()) > 0)
62
-
63
- # Filter the rows so we only include the desired species
64
- return df[df["Species"].isin(input.species())]
65
-
66
- @output
67
- @render.plot
68
- def scatter():
69
- """Generates a plot for Shiny to display to the user"""
70
-
71
- # The plotting function to use depends on whether margins are desired
72
- plotfunc = sns.jointplot if input.show_margins() else sns.scatterplot
73
-
74
- plotfunc(
75
- data=filtered_df(),
76
- x=input.xvar(),
77
- y=input.yvar(),
78
- palette=palette,
79
- hue="Species" if input.by_species() else None,
80
- hue_order=species,
81
- legend=False,
82
- )
83
-
84
- @output
85
- @render.ui
86
- def value_boxes():
87
- df = filtered_df()
88
-
89
- def penguin_value_box(title: str, count: int, bgcol: str, showcase_img: str):
90
- return ui.value_box(
91
- title,
92
- count,
93
- {"class_": "pt-1 pb-0"},
94
- showcase=ui.fill.as_fill_item(
95
- ui.tags.img(
96
- {"style": "object-fit:contain;"},
97
- src=showcase_img,
98
- )
99
- ),
100
- theme_color=None,
101
- style=f"background-color: {bgcol};",
102
- )
103
-
104
- if not input.by_species():
105
- return penguin_value_box(
106
- "Penguins",
107
- len(df.index),
108
- bg_palette["default"],
109
- # Artwork by @allison_horst
110
- showcase_img="penguins.png",
111
- )
112
-
113
- value_boxes = [
114
- penguin_value_box(
115
- name,
116
- len(df[df["Species"] == name]),
117
- bg_palette[name],
118
- # Artwork by @allison_horst
119
- showcase_img=f"{name}.png",
120
- )
121
- for name in species
122
- # Only include boxes for _selected_ species
123
- if name in input.species()
124
- ]
125
-
126
- return ui.layout_column_wrap(*value_boxes, width = 1 / len(value_boxes))
127
-
128
-
129
- # "darkorange", "purple", "cyan4"
130
- colors = [[255, 140, 0], [160, 32, 240], [0, 139, 139]]
131
- colors = [(r / 255.0, g / 255.0, b / 255.0) for r, g, b in colors]
132
-
133
- palette: Dict[str, Tuple[float, float, float]] = {
134
- "Adelie": colors[0],
135
- "Chinstrap": colors[1],
136
- "Gentoo": colors[2],
137
- "default": sns.color_palette()[0], # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  }
139
 
140
- bg_palette = {}
141
- # Use `sns.set_style("whitegrid")` to help find approx alpha value
142
- for name, col in palette.items():
143
- # Adjusted n_colors until `axe` accessibility did not complain about color contrast
144
- bg_palette[name] = mpl_colors.to_hex(sns.light_palette(col, n_colors=7)[1]) # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- app = App(
148
- app_ui,
149
- server,
150
- static_assets=str(www_dir),
151
- )
 
 
 
1
+ from shiny import render
2
+ from shiny.express import input, ui
3
+ from datasets import load_dataset
4
+ import pandas as pd
5
+ from pathlib import Path
6
+ import matplotlib
7
+ import numpy as np
8
+
9
+ ############################################################# 2D Line Plot ########################################################
10
+ ### dvq stuff, obvs this will just be an import in the final version
11
+ from typing import Dict, Optional
12
+ from collections import namedtuple
13
+ import numpy as np
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib.style as mplstyle
16
  from pathlib import Path
17
+ # using a faster style for plotting
18
+ mplstyle.use('fast')
19
+
20
+ # Mapping of nucleotides to float coordinates
21
# Step vector added to the running position for each nucleotide when a
# sequence is drawn as a 2-D walk. Every known base maps to a unit-length
# vector; 'N' (also the fallback for unknown characters) does not move.
mapping_easy = dict(
    A=np.array([0.5, -0.8660254037844386]),
    T=np.array([0.5, 0.8660254037844386]),
    G=np.array([0.8660254037844386, -0.5]),
    C=np.array([0.8660254037844386, 0.5]),
    N=np.array([0, 0]),
)
28
+
29
+ # coordinates for x+iy
30
# A point (x, y) in the plane.
# NOTE: these three definitions are repeated verbatim in the FCGR section
# further down the file.
Coord = namedtuple("Coord", "x y")

# State of a CGR encoding: step counter N plus the current (x, y) position.
CGRCoords = namedtuple("CGRCoords", "N x y")

# Corner of the square [-1, 1]^2 assigned to each nucleotide.
DEFAULT_COORDS = {
    "A": Coord(1, 1),
    "C": Coord(-1, 1),
    "G": Coord(-1, -1),
    "T": Coord(1, -1),
}
37
+
38
+ # Function to convert a DNA sequence to a list of coordinates
39
def _dna_to_coordinates(dna_sequence, mapping):
    """Map every character of a DNA sequence to its step vector.

    Parameters
    ----------
    dna_sequence : str
        Sequence to convert; it is upper-cased before lookup.
    mapping : dict
        Maps upper-case nucleotide letters to coordinate vectors. It must
        contain an 'N' entry, which is used for any character not in the map.

    Returns
    -------
    numpy.ndarray
        Array with one row per character. An empty sequence yields an empty
        array of shape (0, dim) so downstream code can still transpose it.
    """
    fallback = mapping['N']
    steps = [mapping.get(base, fallback) for base in dna_sequence.upper()]
    if not steps:
        # np.array([]) would be 1-D; keep the 2-D layout for empty input.
        return np.empty((0, len(fallback)))
    return np.array(steps)
43
+
44
+ # Function to create the cumulative sum of a list of coordinates
45
def _get_cumulative_coords(mapped_coords):
    """Return the running sum of the coordinate rows (summed along axis 0)."""
    return np.asarray(mapped_coords).cumsum(axis=0)
48
+
49
+ # Function to take a list of DNA sequences and plot them in a single figure
50
def plot_2d_sequences(dna_sequences, mapping=mapping_easy, single_sequence=False):
    """Plot one or more DNA sequences as 2-D walks on a single figure.

    Parameters
    ----------
    dna_sequences : iterable of str, or str
        Sequences to draw. A bare string is treated as one sequence.
    mapping : dict
        Nucleotide-to-vector map passed to `_dna_to_coordinates`.
    single_sequence : bool
        Set to True when `dna_sequences` is a single sequence.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding one line per sequence.
    """
    fig, ax = plt.subplots()
    # A str iterates character by character, which would draw one degenerate
    # walk per nucleotide; wrap it even when the caller forgot the flag.
    if single_sequence or isinstance(dna_sequences, str):
        dna_sequences = [dna_sequences]
    for dna_sequence in dna_sequences:
        walk = _get_cumulative_coords(_dna_to_coordinates(dna_sequence, mapping))
        ax.plot(*walk.T)
    return fig
59
+
60
+ # Function to plot a comparison of DNA sequences
61
def plot_2d_comparison(dna_sequences_grouped, labels, mapping=mapping_easy):
    """Plot groups of DNA sequences as 2-D walks, one colour per group.

    `labels[i]` is the legend label for group `i`. Every sequence in a group
    is drawn with that group's colour, and the legend shows each label once.
    Returns the matplotlib figure.
    """
    fig, ax = plt.subplots()
    palette = plt.cm.rainbow(np.linspace(0, 1, len(dna_sequences_grouped)))
    for group_idx, (group, group_color) in enumerate(zip(dna_sequences_grouped, palette)):
        for sequence in group:
            walk = _get_cumulative_coords(_dna_to_coordinates(sequence, mapping))
            ax.plot(*walk.T, color=group_color, label=labels[group_idx])
    # Several lines share a label; keying by label keeps one handle for each.
    handles, legend_labels = ax.get_legend_handles_labels()
    unique_entries = dict(zip(legend_labels, handles))
    ax.legend(unique_entries.values(), unique_entries.keys())
    return fig
74
+
75
+
76
+ ############################################################# Virus Dataset ########################################################
77
# Load the dataset at import time and materialise its 'train' split as a
# DataFrame.
ds = load_dataset('Hack90/virus_tiny')
df = pd.DataFrame(ds['train'])
# Identity mapping name -> name over the distinct organism names.
virus = {name: name for name in df['Organism_Name'].unique()}
81
+
82
+ ############################################################# Filter and Select ########################################################
83
def filter_and_select(group):
    """Return the first three rows of `group`, or None if it has fewer than three."""
    if len(group) < 3:
        return None
    return group.head(3)
86
+
87
+ ############################################################# Wens Method ########################################################
88
+ import numpy as np
89
+
90
# Weight for each 4-character neighbourhood pattern built by
# `_calculate_weighting_full`. The second character is always '1'; the others
# are '1' where that neighbour equals the nucleotide being weighted.
WEIGHTS = {
    '0100': 1/6,
    '0101': 2/6,
    '1100': 3/6,
    '0110': 3/6,
    '1101': 4/6,
    '1110': 5/6,
    '0111': 5/6,
    '1111': 6/6,
}
# Default value of the `L` parameter of the helpers below.
LOWEST_LENGTH = 5000
92
+
93
def _get_subsequences(sequence):
    """Return the 1-based positions of each of 'A', 'C', 'T', 'G' in `sequence`.

    Matching is case-sensitive; other characters are ignored.
    """
    positions = {nuc: [] for nuc in 'ACTG'}
    for position, base in enumerate(sequence, start=1):
        bucket = positions.get(base)
        if bucket is not None:
            bucket.append(position)
    return positions
95
+
96
def _calculate_coordinates_fixed(subsequence, L=LOWEST_LENGTH):
    """Convert 1-based positions to polar coordinates (theta, rho).

    For each position K, theta = 2*pi/(L-1) * (K-1) and rho = sqrt(theta).
    Returns a list of (theta, rho) tuples in the order of `subsequence`.
    """
    polar = []
    for K in subsequence:
        theta = (2 * np.pi / (L - 1)) * (K - 1)
        polar.append((theta, np.sqrt(theta)))
    return polar
98
+
99
def _calculate_weighting_full(sequence, WEIGHTS, L=LOWEST_LENGTH, E=0.0375):
    """Compute one weight per position of `sequence`.

    For an interior position i, the window sequence[i-1:i+3] is turned into a
    4-character pattern: '1' where the element equals sequence[i], else '0'
    (the second character is therefore always '1'). The pattern is looked up
    in `WEIGHTS`; missing patterns weigh 0. Weights at positions i > L are
    multiplied by E. The first position and the last two always get 0.

    Returns a list of length len(sequence). Sequences shorter than 2 used to
    produce a 2-element list; they now produce a list of matching length.
    """
    n = len(sequence)
    if n < 2:
        return [0] * n

    weightings = [0]
    for i in range(1, n - 1):
        weight = 0
        if i < n - 2:
            before, centre, after, after2 = sequence[i - 1:i + 3]
            pattern = "".join(
                '1' if other == centre else '0'
                for other in (before, centre, after, after2)
            )
            weight = WEIGHTS.get(pattern, 0)
            if i > L:
                weight = weight * E
        weightings.append(weight)
    weightings.append(0)
    return weightings
112
+
113
def _centre_of_mass(polar_coordinates, weightings):
    """Return sum_i w_i * ((x_i - x_i*w_i)^2 + (y_i - y_i*w_i)^2).

    (x_i, y_i) are the Cartesian forms of `polar_coordinates` and
    w_i = weightings[i].
    NOTE(review): `weightings` is indexed by position within
    `polar_coordinates`, not by the nucleotide's position in the original
    sequence - confirm this alignment is intended.
    """
    xs, ys = _calculate_standard_coordinates(polar_coordinates)
    total = 0
    for i, (px, py) in enumerate(zip(xs, ys)):
        w = weightings[i]
        total += w * ((px - (px * w)) ** 2 + (py - py * w) ** 2)
    return total
116
+
117
def _normalised_moment_of_inertia(polar_coordinates, weightings):
    """Return sqrt(_centre_of_mass(...) / sum(weightings))."""
    weighted_sum = _centre_of_mass(polar_coordinates, weightings)
    total_weight = sum(weightings)
    return np.sqrt(weighted_sum / total_weight)
120
+
121
def _calculate_standard_coordinates(polar_coordinates):
    """Convert (theta, rho) pairs to Cartesian form.

    Returns two lists, (xs, ys), with x = rho*cos(theta), y = rho*sin(theta).
    """
    xs, ys = [], []
    for theta, rho in polar_coordinates:
        xs.append(rho * np.cos(theta))
        ys.append(rho * np.sin(theta))
    return xs, ys
123
+
124
+
125
def _moments_of_inertia(polar_coordinates, weightings):
    """Return the normalised moment for each entry of `polar_coordinates`, in dict order."""
    return [
        _normalised_moment_of_inertia(coords, weightings)
        for coords in polar_coordinates.values()
    ]
127
+
128
def moment_of_inertia(sequence, WEIGHTS, L=5000, E=0.0375):
    """Return the list of normalised moments of `sequence`, one per nucleotide.

    Polar coordinates are scaled by len(sequence); `L` and `E` are only
    forwarded to the weighting step.
    """
    sequence_length = len(sequence)
    polar_coordinates = {
        nucleotide: _calculate_coordinates_fixed(positions, sequence_length)
        for nucleotide, positions in _get_subsequences(sequence).items()
    }
    weightings = _calculate_weighting_full(sequence, WEIGHTS, L=L, E=E)
    return _moments_of_inertia(polar_coordinates, weightings)
133
+
134
+
135
def similarity_wen(sequence1, sequence2, WEIGHTS, L=5000, E=0.0375):
    """Return the Euclidean distance between the moment vectors of two sequences.

    NOTE(review): the `L` argument is overwritten with the length of the
    shorter sequence before use, so the value passed by the caller has no
    effect - confirm whether that is intended.
    """
    L = min(len(sequence1), len(sequence2))
    inertia1 = moment_of_inertia(sequence1, WEIGHTS, L=L, E=E)
    inertia2 = moment_of_inertia(sequence2, WEIGHTS, L=L, E=E)
    squared_distance = 0
    for first, second in zip(inertia1, inertia2):
        squared_distance += (first - second) ** 2
    return np.sqrt(squared_distance)
141
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Draw a labelled heatmap with a colorbar.

    Parameters
    ----------
    data
        A 2D numpy array of shape (M, N).
    row_labels
        M labels for the rows.
    col_labels
        N labels for the columns.
    ax
        Axes to draw into; the current axes are used when omitted.
    cbar_kw
        Keyword arguments forwarded to the figure's `colorbar` call. Optional.
    cbarlabel
        Label of the colorbar. Optional.
    **kwargs
        Forwarded to `imshow`.

    Returns
    -------
    (im, cbar)
        The image returned by `imshow` and the created colorbar.
    """
    ax = plt.gca() if ax is None else ax
    colorbar_options = {} if cbar_kw is None else cbar_kw

    im = ax.imshow(data, **kwargs)

    cbar = ax.figure.colorbar(im, ax=ax, **colorbar_options)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape[0], data.shape[1]

    # One major tick per cell, labelled with the supplied names.
    ax.set_xticks(np.arange(n_cols), labels=col_labels)
    ax.set_yticks(np.arange(n_rows), labels=row_labels)

    # Column labels go on top, rotated and anchored at their right end.
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
    plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Hide the spines and draw a white grid on minor ticks at the cell edges.
    ax.spines[:].set_visible(False)
    ax.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    ax.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
    ax.tick_params(which="minor", bottom=False, left=False)

    return im, cbar
199
+
200
 
201
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Write the value of every heatmap cell into the cell.

    Parameters
    ----------
    im
        The AxesImage to be labeled.
    data
        Values to print, as a list or numpy array. Anything else (including
        None) means the image's own data is used. Optional.
    valfmt
        Either a `str.format` pattern using `x`, e.g. "$ {x:.2f}", or a
        `matplotlib.ticker.Formatter`. Optional.
    textcolors
        Pair of colors: the first for values at or below the threshold, the
        second for values above it. Optional.
    threshold
        Value in data units separating the two text colors. When None, half
        of the normalised data maximum is used. Optional.
    **textkw
        Forwarded to every `text` call.

    Returns
    -------
    list
        The created text objects, in row-major order.
    """
    if isinstance(data, (list, np.ndarray)):
        # Lists are accepted by the check above but lack `.max()`, `.shape`
        # and 2-D indexing; convert so they work like arrays.
        data = np.asarray(data)
    else:
        data = im.get_array()

    # Express the threshold in the image's normalised colour range.
    if threshold is not None:
        threshold = im.norm(threshold)
    else:
        threshold = im.norm(data.max()) / 2.

    # Centre the text by default; caller-supplied options win.
    kw = dict(horizontalalignment="center",
              verticalalignment="center")
    kw.update(textkw)

    if isinstance(valfmt, str):
        valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

    texts = []
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            # Pick the text colour per cell so it contrasts with the cell.
            kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
            text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
            texts.append(text)

    return texts
258
+
259
def wens_method_heatmap(df, virus_species):
    """Plot a heatmap of pairwise `similarity_wen` values between species.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'Organism_Name' and 'Sequence'.
    virus_species : list of str
        Species to compare; each must occur in `df['Organism_Name']`.

    Returns
    -------
    matplotlib.figure.Figure

    Notes
    -----
    Off-diagonal cells compare the first sequence of each species. Diagonal
    cells compare the first and second sequence of the same species; when a
    species has only one sequence it is compared with itself instead of
    raising IndexError.
    """
    # Filter the frame once per species instead of inside the O(n^2) loop.
    sequences = {
        name: df[df['Organism_Name'] == name]['Sequence'].values
        for name in virus_species
    }

    similarity_df = pd.DataFrame(index=virus_species, columns=virus_species)
    for virus1 in virus_species:
        sequence1 = sequences[virus1][0]
        for virus2 in virus_species:
            if virus1 == virus2:
                own = sequences[virus1]
                sequence2 = own[1] if len(own) > 1 else own[0]
            else:
                sequence2 = sequences[virus2][0]
            similarity_df.loc[virus1, virus2] = similarity_wen(sequence1, sequence2, WEIGHTS)
    # The frame was created with object dtype; convert before plotting.
    similarity_df = similarity_df.apply(pd.to_numeric)

    fig, ax = plt.subplots()
    im = ax.imshow(similarity_df, cmap="YlGn")
    ax.set_xticks(np.arange(len(virus_species)), labels=virus_species)
    ax.set_yticks(np.arange(len(virus_species)), labels=virus_species)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Similarity", rotation=-90, va="bottom")

    return fig
291
+
292
+
293
+ ############################################################# ColorSquare ########################################################
294
+ import math
295
+ import numpy as np
296
+ import matplotlib.pyplot as plt
297
+ from matplotlib.colors import ListedColormap
298
  import pandas as pd
299
+
300
def _spiral_cells(k):
    """Yield (row, col) pairs of a k x k grid in clockwise spiral order from the top-left."""
    left, top, right, bottom = 0, 0, k - 1, k - 1
    while left <= right and top <= bottom:
        for col in range(left, right + 1):  # top row, left to right
            yield top, col
        top += 1
        for row in range(top, bottom + 1):  # right column, downwards
            yield row, right
        right -= 1
        for col in range(right, left - 1, -1):  # bottom row, right to left
            yield bottom, col
        bottom -= 1
        for row in range(bottom, top - 1, -1):  # left column, upwards
            yield row, left
        left += 1


def _fill_spiral(matrix, seq_colors, k):
    """Write `seq_colors` into `matrix` in place along a clockwise spiral.

    Filling starts at matrix[0][0] and stops when either the colours or the
    k x k cells run out; cells that are not reached keep their value.
    """
    for (row, col), color in zip(_spiral_cells(k), seq_colors):
        matrix[row][col] = color
324
+
325
+
326
def _generate_color_square(sequence, virus, save=False, count=0, label=None):
    """Draw one sequence as a spiral colour square on a new 5x5-inch figure.

    Parameters
    ----------
    sequence : str
        Nucleotide sequence; lower-cased before mapping. Characters other than
        a/t/c/g/n are dropped, as `plot_color_square` does, instead of raising
        KeyError.
    virus : str
        Used in the output file name when `save` is true.
    save : bool
        When true, write `color_square_{virus}_{count}.png`.
    count : int
        Index used in the output file name.
    label : str, optional
        Figure title.
    """
    # Colour index per nucleotide; index 4 ('n') is also the background.
    colors = {'a': 0, 't': 1, 'c': 2, 'g': 3, 'n': 4}
    seq_colors = [colors[char] for char in sequence.lower() if char in colors]

    # Smallest square that holds the whole sequence.
    k = math.ceil(math.sqrt(len(seq_colors)))
    matrix = np.full((k, k), colors['n'], dtype=int)
    _fill_spiral(matrix, seq_colors, k)

    cmap = ListedColormap(['red', 'green', 'yellow', 'blue', 'white'])

    plt.figure(figsize=(5, 5))
    # Pin the colour range: with autoscaling, a matrix that lacks some of the
    # five indices would be stretched over the colormap and nucleotides would
    # get the wrong colour.
    plt.imshow(matrix, cmap=cmap, interpolation='nearest',
               vmin=0, vmax=len(colors) - 1)
    if label:
        plt.title(label)
    plt.axis('off')
    if save:
        plt.savefig(f'color_square_{virus}_{count}.png', dpi=300, bbox_inches='tight')
352
+
353
def plot_color_square(df, virus_species):
    """Draw a grid of spiral colour squares, three columns per species row.

    Parameters
    ----------
    df
        Container indexed by integer position; `df[i]` must be the sequence
        for subplot `i` (row-major), for i in 0 .. 3*len(virus_species) - 1.
        NOTE(review): presumably three sequences per species, in species
        order - confirm against the caller.
    virus_species : list of str
        One entry per grid row; used as the title of that row's subplots.

    Returns
    -------
    matplotlib.figure.Figure
    """
    ncols = 3
    nrows = len(virus_species)
    fig, axeses = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        squeeze=False,
    )
    # Loop-invariant: colour index per nucleotide and the matching colormap.
    colors = {'a': 0, 't': 1, 'c': 2, 'g': 3, 'n': 4}
    cmap = ListedColormap(['red', 'green', 'yellow', 'blue', 'white'])

    for i in range(ncols * nrows):
        row, col = divmod(i, ncols)
        # Keep only characters that have a colour.
        seq_colors = [colors[char] for char in df[i].lower() if char in colors]

        # Smallest square that holds the sequence, pre-filled with 'n' (white).
        k = math.ceil(math.sqrt(len(seq_colors)))
        matrix = np.full((k, k), colors['n'], dtype=int)
        _fill_spiral(matrix, seq_colors, k)

        axes = axeses[row, col]
        # Pin the colour range so a matrix lacking some indices is not
        # rescaled onto the wrong colours.
        axes.imshow(matrix, cmap=cmap, interpolation='nearest',
                    vmin=0, vmax=len(colors) - 1)
        axes.set_title(virus_species[row])
    return fig
388
+
389
+
390
+
391
def generate_color_square(sequence, virus, multi=False, save=False, label=None):
    """Draw colour squares for one sequence or for a collection of sequences.

    Parameters
    ----------
    sequence : str or iterable of str
        A single sequence, or several when `multi` is true.
    virus : str
        Name used in the file names of saved figures.
    multi : bool
        Treat `sequence` as an iterable of sequences.
    save : bool
        Save each figure to disk.
    label : str or sequence of str, optional
        Title; with `multi`, `label[i]` is the title of sequence `i`.
    """
    if multi:
        for i, seq in enumerate(sequence):
            _generate_color_square(seq, virus, save, i, label[i] if label else None)
    else:
        # `virus` must be passed here: previously `save` landed in the `virus`
        # slot, so the figure was never saved in single-sequence mode.
        _generate_color_square(sequence, virus, save, label=label)
397
+
398
+
399
+ ############################################################# FCGR ########################################################
400
+
401
+ from typing import Dict, Optional
402
+ from collections import namedtuple
403
+
404
+ # coordinates for x+iy
405
Coord = namedtuple("Coord", ["x", "y"])

# coordinates for a CGR encoding: step counter N and position (x, y)
CGRCoords = namedtuple("CGRCoords", ["N", "x", "y"])

# coordinates for each nucleotide in the 2d-plane
DEFAULT_COORDS = dict(A=Coord(1, 1), C=Coord(-1, 1), G=Coord(-1, -1), T=Coord(1, -1))


class CGR:
    "Chaos Game Representation for DNA"

    def __init__(self, coords: Optional[Dict[str, tuple]] = None):
        """Use `coords` (nucleotide -> Coord) or DEFAULT_COORDS when omitted."""
        self.nucleotide_coords = DEFAULT_COORDS if coords is None else coords
        self.cgr_coords = CGRCoords(0, 0, 0)

    def nucleotide_by_coords(self, x, y):
        "Get nucleotide by coordinates (x,y); raises IndexError if none matches"
        target = Coord(x, y)
        matches = [nuc for nuc, corner in self.nucleotide_coords.items() if corner == target]
        return matches[0]

    def forward(self, nucleotide: str):
        "Compute next CGR coordinates: the midpoint of the current position and the nucleotide's corner"
        corner = self.nucleotide_coords.get(nucleotide)
        if corner is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise ValueError(f"no CGR coordinates defined for nucleotide {nucleotide!r}")
        x = (self.cgr_coords.x + corner.x) / 2
        y = (self.cgr_coords.y + corner.y) / 2

        # update cgr_coords
        self.cgr_coords = CGRCoords(self.cgr_coords.N + 1, x, y)

    def backward(self,):
        "Compute last CGR coordinates. Current nucleotide can be inferred from (x,y)"
        # get current nucleotide based on coordinates
        n_x, n_y = self.coords_current_nucleotide()
        nucleotide = self.nucleotide_by_coords(n_x, n_y)

        # invert the midpoint step of `forward`
        x = 2 * self.cgr_coords.x - n_x
        y = 2 * self.cgr_coords.y - n_y

        # update cgr_coords
        self.cgr_coords = CGRCoords(self.cgr_coords.N - 1, x, y)

        return nucleotide

    def coords_current_nucleotide(self,):
        "Return the corner (+-1, +-1) given by the signs of the current position"
        x = 1 if self.cgr_coords.x > 0 else -1
        y = 1 if self.cgr_coords.y > 0 else -1
        return x, y

    def encode(self, sequence: str):
        "From DNA sequence to CGR"
        # reset starting position to (0,0,0)
        self.reset_coords()
        for nucleotide in sequence:
            self.forward(nucleotide)
        return self.cgr_coords

    def reset_coords(self,):
        "Reset the state to N=0 at the origin"
        self.cgr_coords = CGRCoords(0, 0, 0)

    def decode(self, N: int, x: float, y: float) -> str:
        "From CGR to DNA sequence"
        self.cgr_coords = CGRCoords(N, x, y)

        # nucleotides are recovered last-to-first
        sequence = []

        # Recover the entire genome
        while self.cgr_coords.N > 0:
            nucleotide = self.backward()
            sequence.append(nucleotide)
        return "".join(sequence[::-1])
477
+
478
+
479
+ from itertools import product
480
+ from collections import defaultdict
481
+ import numpy as np
482
+
483
class FCGR(CGR):
    """Frequency matrix CGR
    an (2**k x 2**k) 2D representation will be created for a
    n-long sequence.
    - k represents the k-mer.
    - 2**k x 2**k = 4**k the total number of k-mers (sequences of length k)
    - pixel value correspond to the value of the frequency for each k-mer
    """

    def __init__(self, k: int,):
        super().__init__()
        self.k = k  # k-mer representation
        self.kmers = ["".join(kmer) for kmer in product("ACGT", repeat=self.k)]
        # Start empty so `kmer_probabilities` is usable before any counting.
        self.freq_kmer = defaultdict(int)
        self.probabilities = defaultdict(float)
        self.kmer2pixel = self.kmer2pixel_position()

    def __call__(self, sequence: str):
        "Given a DNA sequence, returns an array with his frequencies in the same order as FCGR"
        self.count_kmers(sequence)

        # Create an empty array to save the FCGR values
        array_size = int(2**self.k)
        freq_matrix = np.zeros((array_size, array_size))

        # Assign frequency to each box in the matrix (pixel positions are 1-based)
        for kmer, freq in self.freq_kmer.items():
            pos_x, pos_y = self.kmer2pixel[kmer]
            freq_matrix[int(pos_x) - 1, int(pos_y) - 1] = freq
        return freq_matrix

    def count_kmer(self, kmer):
        "Count one k-mer unless it contains an 'N'"
        if "N" not in kmer:
            self.freq_kmer[kmer] += 1

    def count_kmers(self, sequence: str):
        "Reset `freq_kmer` and count every k-long window of `sequence`"
        self.freq_kmer = defaultdict(int)
        # number of k-long windows in the sequence
        last_j = len(sequence) - self.k + 1
        for i in range(last_j):
            self.count_kmer(sequence[i:(i + self.k)])

    def kmer_probabilities(self, sequence: str):
        "Fill `probabilities` with count / (len(sequence) - k + 1) for the counts in `freq_kmer`"
        self.probabilities = defaultdict(float)
        N = len(sequence)
        for key, value in self.freq_kmer.items():
            self.probabilities[key] = float(value) / (N - self.k + 1)

    def pixel_position(self, kmer: str):
        "Get pixel position in the FCGR matrix for a k-mer"

        coords = self.encode(kmer)
        x, y = coords.x, coords.y

        # Coordinates from [-1,1]² to [1,2**k]²
        np_coords = np.array([(x + 1) / 2, (y + 1) / 2])  # move coordinates from [-1,1]² to [0,1]²
        np_coords *= 2**self.k  # rescale coordinates from [0,1]² to [0,2**k]²
        x, y = np.ceil(np_coords)  # round to upper integer

        # Turn coordinates (cx,cy) into pixel (px,py) position
        # px = 2**k-cy+1, py = cx
        return 2**self.k - int(y) + 1, int(x)

    def kmer2pixel_position(self,):
        "Return a dict mapping every k-mer in `kmers` to its pixel position"
        kmer2pixel = dict()
        for kmer in self.kmers:
            kmer2pixel[kmer] = self.pixel_position(kmer)
        return kmer2pixel
550
+
551
+
552
+ from tqdm import tqdm
553
+ from pathlib import Path
554
+
555
+ import numpy as np
556
+
557
+
558
class GenerateFCGR:
    "Convert sequences to FCGR matrices and count how many were converted"

    def __init__(self, kmer: int = 5, ):
        self.kmer = kmer
        self.fcgr = FCGR(kmer)
        self.counter = 0  # count number of time a sequence is converted to fcgr

    def __call__(self, list_fasta,):
        # NOTE(review): `from_fasta` is not defined on this class, so this
        # call raises AttributeError unless a subclass provides it.
        for fasta in tqdm(list_fasta, desc="Generating FCGR"):
            self.from_fasta(fasta)

    def from_seq(self, seq: str):
        "Get FCGR from a sequence"
        seq = self.preprocessing(seq)
        chaos = self.fcgr(seq)
        self.counter += 1
        return chaos

    def reset_counter(self,):
        "Set the conversion counter back to zero"
        self.counter = 0

    @staticmethod
    def preprocessing(seq):
        "Upper-case `seq` and replace every character other than A/T/C/G with 'N'"
        # Single pass; the previous version called str.replace once per
        # offending character, re-scanning the whole string each time.
        return "".join(base if base in "ATCG" else "N" for base in seq.upper())
590
+
591
def plot_fcgr(df, virus_species):
    """Draw a grid of FCGR images, three columns per species row.

    Parameters
    ----------
    df
        Container indexed by integer position; `df[i]` must be the sequence
        for subplot `i` (row-major), for i in 0 .. 3*len(virus_species) - 1.
    virus_species : list of str
        One entry per grid row; used as the title of that row's subplots.

    Returns
    -------
    matplotlib.figure.Figure
    """
    ncols = 3
    nrows = len(virus_species)
    fig, axeses = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        squeeze=False,
    )
    # Build the generator once; constructing it computes the pixel position
    # of every k-mer, which does not depend on the sequence being plotted.
    generator = GenerateFCGR()
    for i in range(ncols * nrows):
        row, col = divmod(i, ncols)
        axes = axeses[row, col]
        chaos = generator.from_seq(seq=df[i].upper())
        axes.imshow(chaos)
        axes.set_title(virus_species[row])
    return fig
609
+
610
+ ############################################################# Persistent Homology ########################################################
611
+ import numpy as np
612
+ import persim
613
+ import ripser
614
+ import matplotlib.pyplot as plt
615
+
616
# One-hot vector for each nucleotide, keyed by lower-case letter.
NUCLEOTIDE_MAPPING = {
    'a': np.array([1, 0, 0, 0]),
    'c': np.array([0, 1, 0, 0]),
    'g': np.array([0, 0, 1, 0]),
    't': np.array([0, 0, 0, 1])
}


def encode_nucleotide_to_vector(nucleotide):
    """Return the one-hot vector for `nucleotide`, or None if it is not a/c/g/t.

    The lookup is case-insensitive, so upper-case sequences are encoded too
    instead of silently mapping every character to None.
    """
    return NUCLEOTIDE_MAPPING.get(nucleotide.lower())
625
+
626
def chaos_4d_representation(dna_sequence):
    """Return the 4-D chaos-game point cloud of a DNA sequence.

    The first encodable nucleotide contributes its own vector; every later
    one contributes the midpoint of the previous point and its vector.
    Characters that `encode_nucleotide_to_vector` maps to None are skipped,
    including at the start of the sequence (this used to raise TypeError).

    Returns
    -------
    numpy.ndarray
        One row of 4 values per encodable nucleotide; shape (0, 4) when there
        is none (an empty sequence used to raise IndexError).
    """
    points = []
    for nucleotide in dna_sequence:
        vector = encode_nucleotide_to_vector(nucleotide)
        if vector is None:
            continue
        if points:
            points.append(0.5 * (points[-1] + vector))
        else:
            points.append(vector)
    if not points:
        # Keep the 4-column layout of the mapping vectors for empty results.
        return np.empty((0, 4))
    return np.array(points)
635
+
636
def persistence_homology(dna_sequence, multi=False, plot=False, sample_rate=7):
    """Compute persistence diagrams (up to dimension 1) with ripser.

    Parameters
    ----------
    dna_sequence : str or iterable of str
        One sequence, or several when `multi` is true.
    multi : bool
        Treat `dna_sequence` as a collection of sequences.
    plot : bool
        Also plot the dimension-1 diagram(s) with persim.
    sample_rate : int
        Only every `sample_rate`-th point of the point cloud is passed to
        ripser.

    Returns
    -------
    list
        The 'dgms' entry of ripser's result for a single sequence, or a list
        of those (one per sequence) when `multi` is true.
    """
    if multi:
        # Keep the point clouds in a plain list: sequences of different
        # lengths give arrays of different shapes, and wrapping them in
        # np.array raises ValueError on current NumPy (ragged input).
        c4dr_points = [chaos_4d_representation(sequence) for sequence in dna_sequence]
        dgm_dna = [ripser.ripser(points[::sample_rate], maxdim=1)['dgms'] for points in c4dr_points]
        if plot:
            persim.plot_diagrams([dgm[1] for dgm in dgm_dna], labels=[f'sequence {i}' for i in range(len(dna_sequence))])
    else:
        c4dr_points = chaos_4d_representation(dna_sequence)
        dgm_dna = ripser.ripser(c4dr_points[::sample_rate], maxdim=1)['dgms']
        if plot:
            persim.plot_diagrams(dgm_dna[1])
    return dgm_dna
648
+
649
def plot_diagrams(
    diagrams,
    plot_only=None,
    title=None,
    xy_range=None,
    labels=None,
    colormap="default",
    size=20,
    ax_color=np.array([0.0, 0.0, 0.0]),
    diagonal=True,
    lifetime=False,
    legend=True,
    show=False,
    ax=None
):
    """Plot one or more persistence diagrams on a single set of axes.

    Parameters
    ----------
    diagrams: ndarray (n_pairs, 2) or list of diagrams
        A diagram or list of diagrams. If diagram is a list of diagrams,
        then plot all on the same plot using different colors.
    plot_only: list of numeric
        If specified, the indices of the only diagrams that should be plotted.
    title: string, default is None
        If title is defined, add it as title of the plot.
    xy_range: list of numeric [xmin, xmax, ymin, ymax]
        User provided range of axes. This is useful for comparing
        multiple persistence diagrams. When omitted (or empty) the range is
        derived from the finite values of the diagrams.
    labels: string or list of strings
        Legend labels for each diagram. A single string is reused for
        every diagram.
        If none are specified, we use H_0, H_1, H_2,... by default.
    colormap: string, default is 'default'
        Name of a matplotlib style, passed to ``plt.style.use``.
        See all available styles with

        .. code:: python

            import matplotlib.pyplot as plt
            print(plt.style.available)

    size: numeric, default is 20
        Marker size of each point plotted (forwarded to ``ax.scatter``).
    ax_color: any valid matplotlib color type.
        Color of the diagonal and of the lifetime horizon line.
        See https://matplotlib.org/api/colors_api.html for complete API.
    diagonal: bool, default is True
        Plot the diagonal x=y line.
    lifetime: bool, default is False. If True, diagonal is turned to False.
        Plot life time of each point instead of birth and death.
        Essentially, visualize (x, y-x).
    legend: bool, default is True
        If true, show the legend.
    show: bool, default is False
        Call plt.show() after plotting. If you are using this as part
        of a subplot, set show=False and call plt.show() only once at the end.
    ax: matplotlib Axes, default is None
        Axes to draw on. A new figure and axes are created when omitted.

    Returns
    -------
    fig, ax
        The figure that owns the axes, and the axes that were drawn on.
    """

    # The original unpacked `plt.subplots() if ax is None else ax` into
    # (fig, ax), which raised as soon as a caller supplied their own axes.
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    plt.style.use(colormap)

    xlabel, ylabel = "Birth", "Death"

    if not isinstance(diagrams, list):
        # Must have diagrams as a list for processing downstream
        diagrams = [diagrams]

    if labels is None:
        # Provide default labels for diagrams
        labels = ["$H_{{{}}}$".format(i) for i, _ in enumerate(diagrams)]
    elif isinstance(labels, str):
        # Expand a single label before `plot_only` indexes into it; otherwise
        # the string itself would be indexed character by character.
        labels = [labels] * len(diagrams)

    if plot_only:
        diagrams = [diagrams[i] for i in plot_only]
        labels = [labels[i] for i in plot_only]

    if not isinstance(labels, list):
        labels = [labels] * len(diagrams)

    # Construct copy with proper type of each diagram
    # so we can freely edit them.
    diagrams = [dgm.astype(np.float32, copy=True) for dgm in diagrams]

    # find min and max of all visible diagrams
    if diagrams:
        concat_dgms = np.concatenate(diagrams).flatten()
    else:
        # np.concatenate rejects an empty list; fall back to "no points".
        concat_dgms = np.empty(0, dtype=np.float32)
    has_inf = np.any(np.isinf(concat_dgms))
    finite_dgms = concat_dgms[np.isfinite(concat_dgms)]

    # clever bounding boxes of the diagram
    if xy_range is None or len(xy_range) == 0:
        # define bounds of diagram; with no finite value at all (e.g. every
        # diagram is empty) use a degenerate range so the buffer below applies
        if finite_dgms.size:
            ax_min, ax_max = np.min(finite_dgms), np.max(finite_dgms)
        else:
            ax_min, ax_max = 0.0, 0.0
        x_r = ax_max - ax_min

        # Give plot a nice buffer on all sides.
        # x_r == 0 when all finite values coincide (e.g. a single point);
        # the original tested `xy_range == 0`, which is never true here.
        buffer = 1 if x_r == 0 else x_r / 5

        x_down = ax_min - buffer / 2
        x_up = ax_max + buffer

        y_down, y_up = x_down, x_up
    else:
        x_down, x_up, y_down, y_up = xy_range

    yr = y_up - y_down

    if lifetime:

        # Don't plot landscape and diagonal at the same time.
        diagonal = False

        # reset y axis so it doesn't go much below zero
        y_down = -yr * 0.05
        y_up = y_down + yr

        # set custom ylabel
        ylabel = "Lifetime"

        # set diagrams to be (x, y-x)
        for dgm in diagrams:
            dgm[:, 1] -= dgm[:, 0]

        # plot horizon line
        ax.plot([x_down, x_up], [0, 0], c=ax_color)

    # Plot diagonal
    if diagonal:
        ax.plot([x_down, x_up], [x_down, x_up], "--", c=ax_color)

    # Plot inf line
    if has_inf:
        # put inf line slightly below top
        b_inf = y_down + yr * 0.95
        ax.plot([x_down, x_up], [b_inf, b_inf], "--", c="k", label=r"$\infty$")

        # convert each inf in each diagram with b_inf
        for dgm in diagrams:
            dgm[np.isinf(dgm)] = b_inf

    # Plot each diagram
    for dgm, label in zip(diagrams, labels):

        # plot persistence pairs
        ax.scatter(dgm[:, 0], dgm[:, 1], size, label=label, edgecolor="none")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.set_xlim([x_down, x_up])
    ax.set_ylim([y_down, y_up])
    ax.set_aspect('equal', 'box')

    if title is not None:
        ax.set_title(title)

    if legend is True:
        ax.legend(loc="lower right")

    if show is True:
        plt.show()
    return fig, ax
810
+
811
+
812
def plot_persistence_homology(df, virus_species, sample_rate=15):
    """Plot the H1 persistence diagrams of several sequences on one figure.

    Parameters
    ----------
    df: iterable of str
        The sequences to analyse. Despite the name, this is iterated
        directly and each element is lower-cased, so it is expected to be a
        collection of sequence strings (e.g. a pandas Series), not a frame.
    virus_species: iterable
        One name per sequence, in the same order as ``df``. Each legend
        label is ``"<name>_<position>"``.
    sample_rate: int, default is 15
        Stride used to subsample each 4D chaos-game point cloud before it is
        handed to ripser (every ``sample_rate``-th point is kept).

    Returns
    -------
    fig
        The figure returned by ``plot_diagrams``.
    """
    h1_diagrams = []
    for sequence in df:
        points = chaos_4d_representation(sequence.lower())
        # ripser returns one diagram per homology dimension; index 1 is H1.
        dgms = ripser.ripser(points[::sample_rate], maxdim=1)['dgms']
        h1_diagrams.append(dgms[1])

    labels = [f'{virus_specie}_{i}' for i, virus_specie in enumerate(virus_species)]
    fig, _ = plot_diagrams(h1_diagrams, labels=labels)
    return fig
826
+
827
def compare_persistence_homology(dna_sequence1, dna_sequence2):
    """Return the sliced Wasserstein distance between two sequences.

    Each sequence is passed to ``persistence_homology`` and the entry at
    index 1 of each result is compared with ``persim.sliced_wasserstein``.
    """
    first_h1, second_h1 = (
        persistence_homology(sequence)[1]
        for sequence in (dna_sequence1, dna_sequence2)
    )
    return persim.sliced_wasserstein(first_h1, second_h1)
832
 
833
+ ############################################################# UI #################################################################
834
# Page chrome and the two input controls, laid out side by side.
ui.page_opts(fillable=True)
ui.panel_title("How similar are viruses, to themselves and others? Do they have underlying structure?")
with ui.layout_columns():
    with ui.card():
        # Multi-select of the viruses to compare.
        ui.input_selectize(
            "virus_selector",
            "Select your viruses:",
            virus,
            multiple=True,
        )
    with ui.card():
        # The choice strings double as the values the plot handler compares
        # against, so their spelling (including "Persistant") must not be
        # changed here without changing the handler as well.
        ui.input_selectize(
            "plot_type",
            "Select your method:",
            ["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"],
            multiple=False,
        )
851
+
852
+ ############################################################# Plotting ########################################################
853
here = Path(__file__).parent


@render.plot
def plot():
    """Render the comparison figure for the selected viruses and method.

    Reads ``input.virus_selector()`` and ``input.plot_type()``, loads the
    ``Hack90/virus_tiny`` dataset, keeps the rows whose ``Organism_Name`` is
    selected, and dispatches to the plotting helper that matches the chosen
    method.

    Returns
    -------
    The figure produced by the selected helper, or ``None`` when no virus is
    selected or the method is not recognised.
    """
    selected = input.virus_selector()
    if not selected:
        # Nothing to compare yet; the helpers would otherwise receive empty
        # data (plot_diagrams, for one, cannot concatenate zero diagrams).
        return None

    # NOTE(review): the dataset is reloaded on every render; consider caching
    # it at module level if this proves slow.
    ds = load_dataset('Hack90/virus_tiny')
    df = pd.DataFrame(ds['train'])
    df = df[df['Organism_Name'].isin(selected)]

    plot_type = input.plot_type()
    fig = None

    if plot_type == "2D Line":
        # group by virus
        grouped = df.groupby('Organism_Name')['Sequence'].apply(list)
        fig = plot_2d_comparison(grouped, grouped.index)
    elif plot_type == "Wens Method":
        fig = wens_method_heatmap(df, df['Organism_Name'].unique())
    elif plot_type in ("ColorSquare", "Chaos Game Representation", "Persistant Homology"):
        filtered_df = df.groupby('Organism_Name').apply(filter_and_select).reset_index(drop=True)
        sequences = filtered_df['Sequence']
        if plot_type == "ColorSquare":
            fig = plot_color_square(sequences, filtered_df['Organism_Name'].unique())
        elif plot_type == "Chaos Game Representation":
            # Names come from the same frame as the sequences so both are in
            # the same (groupby-sorted) order; the original took them from
            # `df`, whose order of appearance can differ.
            fig = plot_fcgr(sequences, filtered_df['Organism_Name'].unique())
        else:
            fig = plot_persistence_homology(sequences, filtered_df['Organism_Name'])
    return fig
877
 
878
+ # @render.image
879
+ # def image():
880
+ # img = None
881
+ # if input.plot_type() == "ColorSquare":
882
+ # img = {"src": f"color_square_{input.virus_selector()[0]}_0.png", "alt": "ColorSquare"}
883
+ # return img
884
+ # return img
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
- shiny==0.7.0
2
- shinyswatch==0.4.2
3
- seaborn==0.12.2
4
- matplotlib==3.7.1
 
 
1
+ matplotlib
2
+ pandas
3
+ datasets
4
+ dvq
5
+ shiny