Hack90 commited on
Commit
293c610
·
verified ·
1 Parent(s): 3590429

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +823 -0
utils.py ADDED
@@ -0,0 +1,823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from shiny import render
2
+ from shiny.express import input, output, ui
3
+ from datasets import load_dataset
4
+ import pandas as pd
5
+ from pathlib import Path
6
+ import matplotlib
7
+ import numpy as np
8
+ import gradio as gr
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib.style as mplstyle
11
+ from scipy.interpolate import interp1d
12
+ from typing import Dict, Optional
13
+ from collections import namedtuple
14
+
15
+
16
+ # Mapping of nucleotides to float coordinates
17
+ mapping_easy = {
18
+ 'A': np.array([0.5, -0.8660254037844386]),
19
+ 'T': np.array([0.5, 0.8660254037844386]),
20
+ 'G': np.array([0.8660254037844386, -0.5]),
21
+ 'C': np.array([0.8660254037844386, 0.5]),
22
+ 'N': np.array([0, 0])
23
+ }
24
+
25
+ # coordinates for x+iy
26
+ Coord = namedtuple("Coord", ["x","y"])
27
+
28
+ # coordinates for a CGR encoding
29
+ CGRCoords = namedtuple("CGRCoords", ["N","x","y"])
30
+
31
+ # coordinates for each nucleotide in the 2d-plane
32
+ DEFAULT_COORDS = dict(A=Coord(1,1),C=Coord(-1,1),G=Coord(-1,-1),T=Coord(1,-1))
33
+
34
+ # Function to convert a DNA sequence to a list of coordinates
35
+ def _dna_to_coordinates(dna_sequence, mapping):
36
+ dna_sequence = dna_sequence.upper()
37
+ coordinates = np.array([mapping.get(nucleotide, mapping['N']) for nucleotide in dna_sequence])
38
+ return coordinates
39
+
40
+ # Function to create the cumulative sum of a list of coordinates
41
+ def _get_cumulative_coords(mapped_coords):
42
+ cumulative_coords = np.cumsum(mapped_coords, axis=0)
43
+ return cumulative_coords
44
+
45
+ # Function to take a list of DNA sequences and plot them in a single figure
46
def plot_2d_sequences(dna_sequences, mapping=mapping_easy, single_sequence=False):
    """Draw the cumulative 2D walk of one or more DNA sequences on one figure.

    Parameters
    ----------
    dna_sequences
        An iterable of sequences, or a single sequence when
        ``single_sequence`` is true.
    mapping
        Nucleotide-to-coordinate table handed to ``_dna_to_coordinates``.
    single_sequence
        When true, ``dna_sequences`` is wrapped in a one-element list.

    Returns
    -------
    The matplotlib figure holding one line per sequence.
    """
    figure, axis = plt.subplots()
    batch = [dna_sequences] if single_sequence else dna_sequences
    for seq in batch:
        walk = _get_cumulative_coords(_dna_to_coordinates(seq, mapping))
        # walk.T unpacks into the x and y series of the path
        axis.plot(*walk.T)
    return figure
55
+
56
+ # Function to plot a comparison of DNA sequences
57
def plot_2d_comparison(dna_sequences_grouped, labels, mapping=mapping_easy):
    """Plot several groups of DNA walks, one colour and legend entry per group.

    Parameters
    ----------
    dna_sequences_grouped
        A sequence of groups; each group is an iterable of DNA sequences.
    labels
        Indexable of legend labels; ``labels[k]`` names group ``k``.
    mapping
        Nucleotide-to-coordinate table handed to ``_dna_to_coordinates``.

    Returns
    -------
    The matplotlib figure.
    """
    figure, axis = plt.subplots()
    palette = plt.cm.rainbow(np.linspace(0, 1, len(dna_sequences_grouped)))
    for group_index, (group, shade) in enumerate(zip(dna_sequences_grouped, palette)):
        for seq in group:
            walk = _get_cumulative_coords(_dna_to_coordinates(seq, mapping))
            axis.plot(*walk.T, color=shade, label=labels[group_index])

    # Every line carries its group label, so collapse duplicates: keying a
    # dict by label text keeps a single handle per label.
    line_handles, line_labels = axis.get_legend_handles_labels()
    unique_entries = dict(zip(line_labels, line_handles))
    axis.legend(unique_entries.values(), unique_entries.keys())
    return figure
70
+
71
+
72
+ ############################################################# Virus Dataset ########################################################
73
+ #ds = load_dataset('Hack90/virus_tiny')
74
# Load the virus dataset once at import time from the local parquet file.
df = pd.read_parquet('virus_ds.parquet')
# Identity mapping {organism name: organism name} over the distinct organisms.
virus = {name: name for name in df['Organism_Name'].unique()}
77
+
78
+ ############################################################# Filter and Select ########################################################
79
def filter_and_select(group):
    """Return the first three rows of ``group``, or ``None`` if it has fewer than three."""
    if len(group) < 3:
        return None
    return group.head(3)
82
+
83
+ ############################################################# Wens Method ########################################################
84
+ import numpy as np
85
+
86
+ WEIGHTS = {'0100': 1/6, '0101': 2/6, '1100' : 3/6, '0110':3/6, '1101': 4/6, '1110': 5/6,'0111':5/6, '1111': 6/6}
87
+ LOWEST_LENGTH = 5000
88
+
89
+ def _get_subsequences(sequence):
90
+ return {nuc: [i+1 for i, x in enumerate(sequence) if x == nuc] for nuc in 'ACTG'}
91
+
92
def _calculate_coordinates_fixed(subsequence, L=LOWEST_LENGTH):
    """Turn 1-based positions into ``(theta, rho)`` polar pairs.

    For position ``K`` the angle is ``2*pi/(L-1) * (K-1)`` and the radius is
    the square root of that angle.
    """
    polar = []
    for K in subsequence:
        # computed per element, as before, so an empty input never divides
        theta = (2 * np.pi / (L - 1)) * (K - 1)
        polar.append((theta, np.sqrt(theta)))
    return polar
94
+
95
def _calculate_weighting_full(sequence, WEIGHTS, L=LOWEST_LENGTH, E=0.0375):
    """Compute one weight per position from the local repeat pattern.

    For each interior index ``i`` the window ``sequence[i-1:i+3]`` is compared
    with its second symbol, giving a 4-character '0'/'1' pattern whose second
    character is always '1'. The pattern is looked up in ``WEIGHTS``, with 0
    for unknown patterns. Weights at indices greater than ``L`` are scaled by
    ``E``. The first position, the last interior index (its window would be
    too short) and the final position all get weight 0.
    """
    length = len(sequence)
    weightings = [0]
    for i in range(1, length - 1):
        weight = 0
        if i < length - 2:
            window = sequence[i - 1:i + 3]
            centre = window[1]
            flags = [
                '1' if window[0] == centre else '0',
                '1',
                '1' if window[2] == centre else '0',
                '1' if window[3] == centre else '0',
            ]
            weight = WEIGHTS.get("".join(flags), 0)
            if i > L:
                weight = weight * E
        weightings.append(weight)
    weightings.append(0)
    return weightings
108
+
109
def _centre_of_mass(polar_coordinates, weightings):
    """Return the weighted sum ``sum_i w_i * ((x_i - x_i*w_i)**2 + (y_i - y_i*w_i)**2)``.

    ``x_i, y_i`` are the Cartesian form of ``polar_coordinates[i]`` and
    ``w_i`` is ``weightings[i]``.

    NOTE(review): weights are picked by occurrence rank ``i``, not by the
    nucleotide's position in the sequence -- confirm this is intended.
    """
    xs, ys = _calculate_standard_coordinates(polar_coordinates)
    total = 0
    for i in range(len(xs)):
        w = weightings[i]
        total += w * ((xs[i] - (xs[i] * w)) ** 2 + (ys[i] - ys[i] * w) ** 2)
    return total
112
+
113
def _normalised_moment_of_inertia(polar_coordinates, weightings):
    """Return ``sqrt(moment / sum(weightings))`` with the moment taken from ``_centre_of_mass``."""
    total_weight = sum(weightings)
    return np.sqrt(_centre_of_mass(polar_coordinates, weightings) / total_weight)
116
+
117
+ def _calculate_standard_coordinates(polar_coordinates):
118
+ return [rho * np.cos(theta) for theta, rho in polar_coordinates], [rho * np.sin(theta) for theta, rho in polar_coordinates]
119
+
120
+
121
def _moments_of_inertia(polar_coordinates, weightings):
    """Return the normalised moment of inertia of every entry, in dict order.

    ``polar_coordinates`` maps a nucleotide to its list of polar pairs; only
    the values are used.
    """
    moments = []
    for coords in polar_coordinates.values():
        moments.append(_normalised_moment_of_inertia(coords, weightings))
    return moments
123
+
124
def moment_of_inertia(sequence, WEIGHTS, L=5000, E=0.0375):
    """Return the list of per-nucleotide normalised moments of inertia.

    Polar coordinates are scaled by ``len(sequence)``. ``L`` and ``E`` are
    only forwarded to the weighting step.
    """
    sequence_length = len(sequence)
    polar_coordinates = {}
    for nucleotide, positions in _get_subsequences(sequence).items():
        polar_coordinates[nucleotide] = _calculate_coordinates_fixed(positions, sequence_length)
    weightings = _calculate_weighting_full(sequence, WEIGHTS, L=L, E=E)
    return _moments_of_inertia(polar_coordinates, weightings)
129
+
130
+
131
def similarity_wen(sequence1, sequence2, WEIGHTS, L=5000, E=0.0375):
    """Return the Euclidean distance between two sequences' moment-of-inertia vectors.

    The ``L`` argument is ignored: it is replaced by the length of the shorter
    sequence before being forwarded to ``moment_of_inertia``.
    """
    shortest = min(len(sequence1), len(sequence2))
    inertia_a = moment_of_inertia(sequence1, WEIGHTS, L=shortest, E=E)
    inertia_b = moment_of_inertia(sequence2, WEIGHTS, L=shortest, E=E)
    squared_distance = 0
    for a, b in zip(inertia_a, inertia_b):
        squared_distance += (a - b) ** 2
    return np.sqrt(squared_distance)
137
def heatmap(data, row_labels, col_labels, ax=None,
            cbar_kw=None, cbarlabel="", **kwargs):
    """
    Draw ``data`` as a labelled image with an attached colorbar.

    Parameters
    ----------
    data
        A 2D numpy array of shape (M, N).
    row_labels
        M labels, one per row.
    col_labels
        N labels, one per column.
    ax
        Axes to draw on. The current axes are used when omitted.
    cbar_kw
        Extra keyword arguments for the figure's ``colorbar`` call. Optional.
    cbarlabel
        Text for the colorbar axis. Optional.
    **kwargs
        Forwarded unchanged to ``imshow``.

    Returns
    -------
    The ``(image, colorbar)`` pair.
    """
    axis = plt.gca() if ax is None else ax
    colorbar_options = {} if cbar_kw is None else cbar_kw

    # Image first, then the colorbar that reads its colour scale.
    im = axis.imshow(data, **kwargs)
    cbar = axis.figure.colorbar(im, ax=axis, **colorbar_options)
    cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")

    n_rows, n_cols = data.shape[0], data.shape[1]

    # One major tick per cell, labelled from the supplied lists.
    axis.set_xticks(np.arange(n_cols), labels=col_labels)
    axis.set_yticks(np.arange(n_rows), labels=row_labels)

    # Column labels go on top, rotated and anchored at their right end.
    axis.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
    plt.setp(axis.get_xticklabels(), rotation=-30, ha="right",
             rotation_mode="anchor")

    # Hide the frame and use minor ticks on cell borders to draw a white grid.
    axis.spines[:].set_visible(False)
    axis.set_xticks(np.arange(n_cols + 1) - .5, minor=True)
    axis.set_yticks(np.arange(n_rows + 1) - .5, minor=True)
    axis.grid(which="minor", color="w", linestyle='-', linewidth=3)
    axis.tick_params(which="minor", bottom=False, left=False)

    return im, cbar
194
+
195
+
196
def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
                     textcolors=("black", "white"),
                     threshold=None, **textkw):
    """
    Write the value of every cell on top of a heatmap image.

    Parameters
    ----------
    im
        The AxesImage to label.
    data
        Values to print. When not a list or ndarray, the image's own array
        is used. Optional.
    valfmt
        Either a ``str.format`` template using ``x`` (e.g. "$ {x:.2f}") or a
        `matplotlib.ticker.Formatter`. Optional.
    textcolors
        Two colours: the first for cells at or below the threshold, the
        second for cells above it. Optional.
    threshold
        Switch point in data units. When None, half of the normalised data
        maximum is used. Optional.
    **textkw
        Extra keyword arguments for every ``text`` call.

    Returns
    -------
    The list of created ``Text`` objects, in row-major order.
    """
    if not isinstance(data, (list, np.ndarray)):
        data = im.get_array()

    # Express the threshold on the image's normalised colour scale.
    if threshold is None:
        cutoff = im.norm(data.max()) / 2.
    else:
        cutoff = im.norm(threshold)

    # Centred text by default; caller-supplied options take precedence.
    text_options = dict(horizontalalignment="center",
                        verticalalignment="center")
    text_options.update(textkw)

    formatter = valfmt
    if isinstance(formatter, str):
        formatter = matplotlib.ticker.StrMethodFormatter(formatter)

    texts = []
    for i, j in np.ndindex(data.shape[0], data.shape[1]):
        above = int(im.norm(data[i, j]) > cutoff)
        # the colour is chosen per cell and overrides any colour in textkw
        text_options.update(color=textcolors[above])
        texts.append(im.axes.text(j, i, formatter(data[i, j], None), **text_options))

    return texts
252
+
253
def wens_method_heatmap(df, virus_species):
    """Plot a heatmap of pairwise ``similarity_wen`` values between species.

    Off-diagonal cells compare the first sequence of each species. Diagonal
    cells compare a species' first sequence with its second; when a species
    has only one sequence the first is compared with itself, where the
    original raised ``IndexError``.

    Parameters
    ----------
    df
        DataFrame with ``'Organism_Name'`` and ``'Sequence'`` columns.
    virus_species
        Sequence of organism names to compare; each must occur in ``df``.

    Returns
    -------
    The matplotlib figure holding the heatmap and its colorbar.
    """
    # Filter the DataFrame once per species rather than twice per cell.
    sequences_by_species = {
        name: df[df['Organism_Name'] == name]['Sequence'].values
        for name in virus_species
    }

    similarity_df = pd.DataFrame(index=virus_species, columns=virus_species)
    for virus1 in virus_species:
        sequence1 = sequences_by_species[virus1][0]
        for virus2 in virus_species:
            candidates = sequences_by_species[virus2]
            if virus1 == virus2 and len(candidates) > 1:
                # within-species cell: use a second, distinct sequence
                sequence2 = candidates[1]
            else:
                sequence2 = candidates[0]
            similarity_df.loc[virus1, virus2] = similarity_wen(sequence1, sequence2, WEIGHTS)
    # The frame was created with object dtype; imshow needs numbers.
    similarity_df = similarity_df.apply(pd.to_numeric)

    fig, ax = plt.subplots()
    im = ax.imshow(similarity_df, cmap="YlGn")
    ax.set_xticks(np.arange(len(virus_species)), labels=virus_species)
    ax.set_yticks(np.arange(len(virus_species)), labels=virus_species)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Similarity", rotation=-90, va="bottom")

    return fig
285
+
286
+
287
+ ############################################################# ColorSquare ########################################################
288
+ import math
289
+ import numpy as np
290
+ import matplotlib.pyplot as plt
291
+ from matplotlib.colors import ListedColormap
292
+ import pandas as pd
293
+
294
+ def _fill_spiral(matrix, seq_colors, k):
295
+ left, top, right, bottom = 0, 0, k-1, k-1
296
+ index = 0
297
+ while left <= right and top <= bottom:
298
+ for i in range(left, right + 1): # Top row
299
+ if index < len(seq_colors):
300
+ matrix[top][i] = seq_colors[index]
301
+ index += 1
302
+ top += 1
303
+ for i in range(top, bottom + 1): # Right column
304
+ if index < len(seq_colors):
305
+ matrix[i][right] = seq_colors[index]
306
+ index += 1
307
+ right -= 1
308
+ for i in range(right, left - 1, -1): # Bottom row
309
+ if index < len(seq_colors):
310
+ matrix[bottom][i] = seq_colors[index]
311
+ index += 1
312
+ bottom -= 1
313
+ for i in range(bottom, top - 1, -1): # Left column
314
+ if index < len(seq_colors):
315
+ matrix[i][left] = seq_colors[index]
316
+ index += 1
317
+ left += 1
318
+
319
+
320
def _generate_color_square(sequence, virus, save=False, count=0, label=None):
    """Render one sequence as a spiral colour square on a new 5x5 inch figure.

    Characters other than a/t/c/g/n (case-insensitive) are dropped before
    plotting, matching ``plot_color_square``; previously they raised
    ``KeyError``.

    Parameters
    ----------
    sequence
        Nucleotide string to draw.
    virus
        Name used in the output file name when ``save`` is true.
    save
        When true, write ``color_square_{virus}_{count}.png`` at 300 dpi.
    count
        Index used in the output file name.
    label
        Optional plot title.
    """
    # Colour-map index for each nucleotide; 'n' (index 4) is white.
    colors = {'a': 0, 't': 1, 'c': 2, 'g': 3, 'n': 4}
    seq_colors = [colors[char] for char in sequence.lower() if char in colors]

    # Smallest square that holds every kept symbol.
    k = math.ceil(math.sqrt(len(seq_colors)))

    # Start all white so unused trailing cells stay blank.
    matrix = np.full((k, k), colors['n'], dtype=int)
    _fill_spiral(matrix, seq_colors, k)

    cmap = ListedColormap(['red', 'green', 'yellow', 'blue', 'white'])

    plt.figure(figsize=(5, 5))
    plt.imshow(matrix, cmap=cmap, interpolation='nearest')
    if label:
        plt.title(label)
    plt.axis('off')
    if save:
        plt.savefig(f'color_square_{virus}_{count}.png', dpi=300, bbox_inches='tight')
346
+
347
def plot_color_square(df, virus_species):
    """Draw a grid of spiral colour squares, three columns per species row.

    Parameters
    ----------
    df
        Container indexed by ``0 .. 3*len(virus_species) - 1``; entry ``i`` is
        the sequence shown at row ``i // 3``, column ``i % 3``.
    virus_species
        Indexable of species names; ``virus_species[row]`` titles every panel
        in that row.

    Returns
    -------
    The matplotlib figure.
    """
    ncols = 3
    nrows = len(virus_species)
    fig, grid = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)

    # Colour-map index per nucleotide; index 4 ('n') is white.
    palette_index = {'a': 0, 't': 1, 'c': 2, 'g': 3, 'n': 4}
    cmap = ListedColormap(['red', 'green', 'yellow', 'blue', 'white'])

    for row in range(nrows):
        for col in range(ncols):
            panel = grid[row, col]
            raw = df[row * ncols + col]
            species_name = virus_species[row]

            # Keep only recognised nucleotide symbols.
            cleaned = ''.join(ch for ch in raw.lower() if ch in 'atcgn')
            seq_colors = [palette_index[ch] for ch in cleaned.lower()]

            # Smallest square that holds the cleaned sequence, all white at first.
            side = math.ceil(math.sqrt(len(cleaned)))
            square = np.full((side, side), palette_index['n'], dtype=int)
            _fill_spiral(square, seq_colors, side)

            panel.imshow(square, cmap=cmap, interpolation='nearest')
            panel.set_title(species_name)
    return fig
382
+
383
+
384
+
385
def generate_color_square(sequence, virus, multi=False, save=False, label=None):
    """Render one colour square, or one per sequence when ``multi`` is true.

    Parameters
    ----------
    sequence
        A single nucleotide string, or an iterable of them when ``multi``.
    virus
        Name forwarded to ``_generate_color_square`` for the output file name.
    multi
        Treat ``sequence`` as a collection and draw each entry.
    save
        Forwarded to ``_generate_color_square``.
    label
        A single title, or an indexable of titles when ``multi``.
    """
    if multi:
        for i, seq in enumerate(sequence):
            _generate_color_square(seq, virus, save, i, label[i] if label else None)
    else:
        # Fix: the original passed `save` in the `virus` position and dropped
        # `virus`, so the name was wrong and `save` was never honoured.
        _generate_color_square(sequence, virus, save, label=label)
391
+
392
+
393
+ ############################################################# FCGR ########################################################
394
+
395
+ from typing import Dict, Optional
396
+ from collections import namedtuple
397
+
398
# coordinates for x+iy
Coord = namedtuple("Coord", ["x", "y"])

# coordinates for a CGR encoding: number of encoded nucleotides and position
CGRCoords = namedtuple("CGRCoords", ["N", "x", "y"])

# coordinates for each nucleotide in the 2d-plane (corners of [-1, 1]^2)
DEFAULT_COORDS = dict(A=Coord(1, 1), C=Coord(-1, 1), G=Coord(-1, -1), T=Coord(1, -1))


class CGR:
    """Chaos Game Representation for DNA.

    The state ``cgr_coords`` is a ``CGRCoords(N, x, y)`` triple. Each encoded
    nucleotide moves the point halfway towards that nucleotide's corner and
    increments ``N``.
    """

    def __init__(self, coords: Optional[Dict[str, tuple]] = None):
        """Use ``coords`` as the nucleotide-to-corner table, or ``DEFAULT_COORDS``."""
        self.nucleotide_coords = DEFAULT_COORDS if coords is None else coords
        self.cgr_coords = CGRCoords(0, 0, 0)

    def nucleotide_by_coords(self, x, y):
        """Return the nucleotide whose corner is ``(x, y)``.

        Raises ``IndexError`` when no nucleotide sits at that corner.
        """
        corner = Coord(x, y)
        matches = [
            nucleotide
            for nucleotide, position in self.nucleotide_coords.items()
            if position == corner
        ]
        return matches[0]

    def forward(self, nucleotide: str):
        """Move the current point halfway towards ``nucleotide``'s corner."""
        # NOTE: an unknown nucleotide yields None here and fails on `.x`
        # with AttributeError, as in the original.
        corner = self.nucleotide_coords.get(nucleotide)
        x = (self.cgr_coords.x + corner.x) / 2
        y = (self.cgr_coords.y + corner.y) / 2

        self.cgr_coords = CGRCoords(self.cgr_coords.N + 1, x, y)

    def backward(self,):
        """Undo the latest ``forward`` step and return the nucleotide it encoded.

        The nucleotide is inferred from the quadrant of the current point.
        """
        n_x, n_y = self.coords_current_nucleotide()
        nucleotide = self.nucleotide_by_coords(n_x, n_y)

        # Invert the midpoint rule: previous = 2*current - corner.
        x = 2 * self.cgr_coords.x - n_x
        y = 2 * self.cgr_coords.y - n_y

        self.cgr_coords = CGRCoords(self.cgr_coords.N - 1, x, y)

        return nucleotide

    def coords_current_nucleotide(self,):
        """Return the ``(+-1, +-1)`` corner of the current point's quadrant.

        A coordinate that is not strictly positive maps to -1.
        """
        x = 1 if self.cgr_coords.x > 0 else -1
        y = 1 if self.cgr_coords.y > 0 else -1
        return x, y

    def encode(self, sequence: str):
        """Encode ``sequence`` from the origin and return the final ``CGRCoords``."""
        self.reset_coords()
        for nucleotide in sequence:
            self.forward(nucleotide)
        return self.cgr_coords

    def reset_coords(self,):
        """Reset the state to ``CGRCoords(0, 0, 0)``."""
        self.cgr_coords = CGRCoords(0, 0, 0)

    def decode(self, N: int, x: float, y: float) -> str:
        """Recover the ``N``-nucleotide sequence that ends at point ``(x, y)``."""
        self.cgr_coords = CGRCoords(N, x, y)

        # Nucleotides come out last-to-first while stepping backwards.
        sequence = []
        while self.cgr_coords.N > 0:
            sequence.append(self.backward())
        return "".join(sequence[::-1])
471
+
472
+
473
+ from itertools import product
474
+ from collections import defaultdict
475
+ import numpy as np
476
+
477
class FCGR(CGR):
    """Frequency matrix CGR.

    A sequence is summarised as a (2**k x 2**k) matrix.
    - k is the k-mer length.
    - 2**k x 2**k = 4**k is the number of distinct k-mers over "ACGT".
    - each pixel holds the count of one k-mer.
    """

    def __init__(self, k: int,):
        super().__init__()
        self.k = k  # k-mer length
        self.kmers = ["".join(letters) for letters in product("ACGT", repeat=self.k)]
        self.kmer2pixel = self.kmer2pixel_position()

    def __call__(self, sequence: str):
        """Count the k-mers of ``sequence`` and return them as a 2**k x 2**k array."""
        self.count_kmers(sequence)

        side = int(2**self.k)
        freq_matrix = np.zeros((side, side))

        # Pixel positions are 1-based, hence the -1 when indexing.
        for kmer, freq in self.freq_kmer.items():
            row, col = self.kmer2pixel[kmer]
            freq_matrix[int(row) - 1, int(col) - 1] = freq
        return freq_matrix

    def count_kmer(self, kmer):
        """Increment the count of ``kmer`` unless it contains an "N"."""
        if "N" in kmer:
            return
        self.freq_kmer[kmer] += 1

    def count_kmers(self, sequence: str):
        """Reset ``freq_kmer`` and count every length-k window of ``sequence``."""
        self.freq_kmer = defaultdict(int)
        window_count = len(sequence) - self.k + 1
        for start in range(window_count):
            self.count_kmer(sequence[start:start + self.k])

    def kmer_probabilities(self, sequence: str):
        """Store ``count / (len(sequence) - k + 1)`` per k-mer in ``self.probabilities``.

        Reads ``self.freq_kmer``, so ``count_kmers`` must have run first.
        """
        self.probabilities = defaultdict(float)
        N = len(sequence)
        for kmer, count in self.freq_kmer.items():
            self.probabilities[kmer] = float(count) / (N - self.k + 1)

    def pixel_position(self, kmer: str):
        """Return the 1-based ``(row, column)`` pixel of ``kmer`` in the FCGR matrix."""
        encoded = self.encode(kmer)

        # Shift [-1, 1]^2 to [0, 1]^2, scale to [0, 2**k]^2, round up to cells.
        scaled = np.array([(encoded.x + 1) / 2, (encoded.y + 1) / 2])
        scaled *= 2**self.k
        cell_x, cell_y = np.ceil(scaled)

        # Cell (cx, cy) becomes pixel (2**k - cy + 1, cx): y is flipped so the
        # top matrix row is the highest y.
        return 2**self.k - int(cell_y) + 1, int(cell_x)

    def kmer2pixel_position(self,):
        """Build the ``{kmer: pixel position}`` table for every k-mer."""
        return {kmer: self.pixel_position(kmer) for kmer in self.kmers}
544
+
545
+
546
+ from tqdm import tqdm
547
+ from pathlib import Path
548
+
549
+ import numpy as np
550
+
551
+
552
class GenerateFCGR:
    """Convenience wrapper that cleans sequences and turns them into FCGR matrices."""

    def __init__(self, kmer: int = 5, ):
        self.kmer = kmer
        self.fcgr = FCGR(kmer)
        self.counter = 0  # number of sequences converted so far

    def __call__(self, list_fasta,):
        """Call ``self.from_fasta`` on every item of ``list_fasta`` with a progress bar.

        NOTE(review): no ``from_fasta`` method is defined on this class, so
        this raises ``AttributeError`` for a non-empty list unless a subclass
        provides one.
        """
        for fasta in tqdm(list_fasta, desc="Generating FCGR"):
            self.from_fasta(fasta)

    def from_seq(self, seq: str):
        """Return the FCGR matrix of ``seq`` and increment the counter."""
        seq = self.preprocessing(seq)
        chaos = self.fcgr(seq)
        self.counter += 1
        return chaos

    def reset_counter(self,):
        """Set the conversion counter back to zero."""
        self.counter = 0

    @staticmethod
    def preprocessing(seq):
        """Upper-case ``seq`` and replace every character outside "ATCG" with "N".

        Done in one pass; the original called ``str.replace`` once per
        offending character, which is quadratic on sequences with many
        ambiguous bases.
        """
        return "".join(letter if letter in "ATCG" else "N" for letter in seq.upper())
584
+
585
def plot_fcgr(df, virus_species):
    """Draw a grid of FCGR images, three columns per species row.

    Parameters
    ----------
    df
        Container indexed by ``0 .. 3*len(virus_species) - 1``; entry ``i`` is
        the sequence shown at row ``i // 3``, column ``i % 3``.
    virus_species
        Indexable of species names; ``virus_species[row]`` titles every panel
        in that row.

    Returns
    -------
    The matplotlib figure.
    """
    ncols = 3
    nrows = len(virus_species)
    fig, axeses = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        squeeze=False,
    )
    # Build the generator once: its constructor precomputes the k-mer pixel
    # table, which the original rebuilt for every panel.
    generator = GenerateFCGR()
    for i in range(0, ncols * nrows):
        row = i // ncols
        col = i % ncols
        axes = axeses[row, col]
        data = df[i].upper()
        chaos = generator.from_seq(seq=data)
        virus = virus_species[row]
        axes.imshow(chaos)
        axes.set_title(virus)
    return fig
603
+
604
+ ############################################################# Persistent Homology ########################################################
605
+ import numpy as np
606
+ import persim
607
+ import ripser
608
+ import matplotlib.pyplot as plt
609
+
610
# One-hot corner of the 4D unit cube for each (lower-case) nucleotide.
NUCLEOTIDE_MAPPING = {
    'a': np.array([1, 0, 0, 0]),
    'c': np.array([0, 1, 0, 0]),
    'g': np.array([0, 0, 1, 0]),
    't': np.array([0, 0, 0, 1])
}


def encode_nucleotide_to_vector(nucleotide):
    """Return the one-hot vector for a lower-case nucleotide, or ``None`` if unknown."""
    return NUCLEOTIDE_MAPPING.get(nucleotide)


def chaos_4d_representation(dna_sequence):
    """Return the 4D chaos-game trajectory of ``dna_sequence``.

    The first recognised nucleotide contributes its own corner. Every later
    recognised nucleotide contributes the midpoint between the previous point
    and its corner. Unrecognised symbols are skipped wherever they occur, and
    matching is case-insensitive.

    Returns
    -------
    A float array of shape ``(n_recognised, 4)``; shape ``(0, 4)`` when the
    sequence has no recognised nucleotide.
    """
    points = []
    for nucleotide in dna_sequence:
        vector = encode_nucleotide_to_vector(nucleotide.lower())
        if vector is None:
            # Also covers an unknown first symbol, which previously seeded
            # the trajectory with None and failed on the next step.
            continue
        if points:
            points.append(0.5 * (points[-1] + vector))
        else:
            points.append(vector)
    # reshape keeps the second dimension at 4 even when nothing was encoded
    return np.array(points, dtype=float).reshape(-1, 4)
629
+
630
def persistence_homology(dna_sequence, multi=False, plot=False, sample_rate=7):
    """Compute persistence diagrams (up to dimension 1) of the 4D chaos trajectory.

    Parameters
    ----------
    dna_sequence
        One sequence, or an iterable of sequences when ``multi`` is true.
    multi
        Process every sequence separately and return a list of results.
    plot
        Also draw the dimension-1 diagram(s) with ``persim.plot_diagrams``.
    sample_rate
        Keep every ``sample_rate``-th trajectory point before running ripser.

    Returns
    -------
    The ``'dgms'`` entry of ripser's result; a list of those when ``multi``.
    """
    if multi:
        # Keep a plain list: trajectories of different lengths cannot be
        # stacked into one ndarray, as the original tried to do.
        c4dr_points = [chaos_4d_representation(sequence) for sequence in dna_sequence]
        dgm_dna = [ripser.ripser(points[::sample_rate], maxdim=1)['dgms'] for points in c4dr_points]
        if plot:
            persim.plot_diagrams(
                [dgm[1] for dgm in dgm_dna],
                labels=[f'sequence {i}' for i in range(len(dna_sequence))],
            )
    else:
        c4dr_points = chaos_4d_representation(dna_sequence)
        dgm_dna = ripser.ripser(c4dr_points[::sample_rate], maxdim=1)['dgms']
        if plot:
            persim.plot_diagrams(dgm_dna[1])
    return dgm_dna
642
+
643
def plot_diagrams(
    diagrams,
    plot_only=None,
    title=None,
    xy_range=None,
    labels=None,
    colormap="default",
    size=20,
    ax_color=np.array([0.0, 0.0, 0.0]),
    diagonal=True,
    lifetime=False,
    legend=True,
    show=False,
    ax=None
):
    """A helper function to plot persistence diagrams.

    Parameters
    ----------
    diagrams: ndarray (n_pairs, 2) or list of diagrams
        A diagram or list of diagrams. A list is drawn on one plot, one
        scatter series per diagram.
    plot_only: list of numeric
        If specified, indices of the diagrams that should be plotted.
    title: string, default is None
        If given, used as the title of the plot.
    xy_range: list of numeric [xmin, xmax, ymin, ymax]
        User provided axis range. Useful for comparing multiple persistence
        diagrams.
    labels: string or list of strings
        Legend labels for each diagram.
        If none are specified, H_0, H_1, H_2,... are used.
    colormap: string, default is 'default'
        Name of a matplotlib style, applied with ``plt.style.use``.
        See all available styles with

        .. code:: python

            import matplotlib as mpl
            print(mpl.style.available)

    size: numeric, default is 20
        Size of each plotted point.
    ax_color: any valid matplotlib color type.
        Colour of the diagonal / horizon line.
    diagonal: bool, default is True
        Plot the diagonal x=y line.
    lifetime: bool, default is False. If True, diagonal is turned to False.
        Plot the lifetime of each point instead of birth and death, i.e.
        visualize (x, y-x).
    legend: bool, default is True
        If true, show the legend.
    show: bool, default is False
        Call plt.show() after plotting. When drawing into a subplot, leave
        this False and call plt.show() once at the end.
    ax: matplotlib Axes, default is None
        Axes to draw on. A new figure and axes are created when omitted.

    Returns
    -------
    ``(fig, ax)``: the figure and axes that were drawn on.
    """

    # Fix: the original `fig, ax = plt.subplots() if ax is None else ax`
    # tried to unpack the Axes itself whenever `ax` was supplied.
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    plt.style.use(colormap)

    xlabel, ylabel = "Birth", "Death"

    if not isinstance(diagrams, list):
        # Must have diagrams as a list for processing downstream
        diagrams = [diagrams]

    if labels is None:
        # Default labels H_0, H_1, ...
        labels = ["$H_{{{}}}$".format(i) for i, _ in enumerate(diagrams)]

    if plot_only:
        diagrams = [diagrams[i] for i in plot_only]
        labels = [labels[i] for i in plot_only]

    if not isinstance(labels, list):
        labels = [labels] * len(diagrams)

    # Work on float32 copies so the caller's arrays are never modified.
    diagrams = [dgm.astype(np.float32, copy=True) for dgm in diagrams]

    # find min and max of all visible diagrams
    concat_dgms = np.concatenate(diagrams).flatten()
    has_inf = np.any(np.isinf(concat_dgms))
    finite_dgms = concat_dgms[np.isfinite(concat_dgms)]

    if not xy_range:
        # define bounds of diagram
        if finite_dgms.size:
            ax_min, ax_max = np.min(finite_dgms), np.max(finite_dgms)
        else:
            # No finite point at all (e.g. empty H_1 diagrams): use a unit
            # box instead of failing on an empty reduction.
            ax_min, ax_max = 0.0, 1.0
        x_r = ax_max - ax_min

        # Give the plot a buffer on all sides. The extent is 0 when only one
        # distinct value is present; the original tested `xy_range == 0`,
        # which can never be true in this branch.
        buffer = 1 if x_r == 0 else x_r / 5

        x_down = ax_min - buffer / 2
        x_up = ax_max + buffer

        y_down, y_up = x_down, x_up
    else:
        x_down, x_up, y_down, y_up = xy_range

    yr = y_up - y_down

    if lifetime:

        # Don't plot lifetimes and the diagonal at the same time.
        diagonal = False

        # reset y axis so it doesn't go much below zero
        y_down = -yr * 0.05
        y_up = y_down + yr

        # set custom ylabel
        ylabel = "Lifetime"

        # set diagrams to be (x, y-x)
        for dgm in diagrams:
            dgm[:, 1] -= dgm[:, 0]

        # plot horizon line
        ax.plot([x_down, x_up], [0, 0], c=ax_color)

    # Plot diagonal
    if diagonal:
        ax.plot([x_down, x_up], [x_down, x_up], "--", c=ax_color)

    # Plot inf line
    if has_inf:
        # put inf line slightly below top
        b_inf = y_down + yr * 0.95
        ax.plot([x_down, x_up], [b_inf, b_inf], "--", c="k", label=r"$\infty$")

        # replace each inf in each diagram with b_inf
        for dgm in diagrams:
            dgm[np.isinf(dgm)] = b_inf

    # Plot each diagram's persistence pairs
    for dgm, label in zip(diagrams, labels):
        ax.scatter(dgm[:, 0], dgm[:, 1], size, label=label, edgecolor="none")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.set_xlim([x_down, x_up])
    ax.set_ylim([y_down, y_up])
    ax.set_aspect('equal', 'box')

    if title is not None:
        ax.set_title(title)

    if legend is True:
        ax.legend(loc="lower right")

    if show is True:
        plt.show()
    return fig, ax
800
+
801
+
802
def plot_persistence_homology(df, virus_species):
    """Plot the dimension-1 persistence diagrams of several sequences together.

    Parameters
    ----------
    df
        Iterable of sequence strings.
    virus_species
        Iterable of species names; entry ``i`` gives the label ``"{name}_{i}"``.

    Returns
    -------
    The matplotlib figure produced by ``plot_diagrams``.

    NOTE(review): labels are built from ``virus_species`` only, and
    ``plot_diagrams`` pairs diagrams with labels via ``zip``. If ``df`` holds
    more sequences than there are species, the extra diagrams are not drawn
    -- confirm against the caller.
    """
    point_clouds = []
    for sequence in df:
        point_clouds.append(chaos_4d_representation(sequence.lower()))

    # Every 15th trajectory point is kept before running ripser.
    all_diagrams = []
    for cloud in point_clouds:
        all_diagrams.append(ripser.ripser(cloud[::15], maxdim=1)['dgms'])

    labels = [f'{virus_specie}_{i}' for i, virus_specie in enumerate(virus_species)]
    h1_diagrams = [dgm[1] for dgm in all_diagrams]
    fig, _ = plot_diagrams(h1_diagrams, labels=labels)
    return fig
816
+
817
def compare_persistence_homology(dna_sequence1, dna_sequence2):
    """Return the sliced Wasserstein distance between two sequences' dimension-1 diagrams."""
    h1_first = persistence_homology(dna_sequence1)[1]
    h1_second = persistence_homology(dna_sequence2)[1]
    return persim.sliced_wasserstein(h1_first, h1_second)
822
+
823
+