Hack90 commited on
Commit
3590429
·
verified ·
1 Parent(s): 6009f69

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -1132
app.py CHANGED
@@ -1,86 +1,16 @@
1
- from shiny import render
2
- from shiny.express import input, ui
3
- from datasets import load_dataset
4
  import pandas as pd
5
- from pathlib import Path
6
- import matplotlib
7
- import numpy as np
8
- import gradio as gr
9
- from shiny.express import input, output, render, ui
10
-
11
- ############################################################# 2D Line Plot ########################################################
12
- ### dvq stuff, obvs this will just be an import in the final version
13
- from typing import Dict, Optional
14
- from collections import namedtuple
15
- import numpy as np
16
- import matplotlib.pyplot as plt
17
- import matplotlib.style as mplstyle
18
- from pathlib import Path
19
- from shiny import render
20
- from shiny.express import input, ui
21
- import pandas as pd
22
- from pathlib import Path
23
  import matplotlib.pyplot as plt
24
- import numpy as np
25
- import pandas as pd
26
  from scipy.interpolate import interp1d
27
- import numpy as np
28
-
29
-
30
- # Mapping of nucleotides to float coordinates
31
- mapping_easy = {
32
- 'A': np.array([0.5, -0.8660254037844386]),
33
- 'T': np.array([0.5, 0.8660254037844386]),
34
- 'G': np.array([0.8660254037844386, -0.5]),
35
- 'C': np.array([0.8660254037844386, 0.5]),
36
- 'N': np.array([0, 0])
37
- }
38
-
39
- # coordinates for x+iy
40
- Coord = namedtuple("Coord", ["x","y"])
41
-
42
- # coordinates for a CGR encoding
43
- CGRCoords = namedtuple("CGRCoords", ["N","x","y"])
44
 
45
- # coordinates for each nucleotide in the 2d-plane
46
- DEFAULT_COORDS = dict(A=Coord(1,1),C=Coord(-1,1),G=Coord(-1,-1),T=Coord(1,-1))
47
 
48
- # Function to convert a DNA sequence to a list of coordinates
49
- def _dna_to_coordinates(dna_sequence, mapping):
50
- dna_sequence = dna_sequence.upper()
51
- coordinates = np.array([mapping.get(nucleotide, mapping['N']) for nucleotide in dna_sequence])
52
- return coordinates
53
-
54
- # Function to create the cumulative sum of a list of coordinates
55
- def _get_cumulative_coords(mapped_coords):
56
- cumulative_coords = np.cumsum(mapped_coords, axis=0)
57
- return cumulative_coords
58
-
59
- # Function to take a list of DNA sequences and plot them in a single figure
60
- def plot_2d_sequences(dna_sequences, mapping=mapping_easy, single_sequence=False):
61
- fig, ax = plt.subplots()
62
- if single_sequence:
63
- dna_sequences = [dna_sequences]
64
- for dna_sequence in dna_sequences:
65
- mapped_coords = _dna_to_coordinates(dna_sequence, mapping)
66
- cumulative_coords = _get_cumulative_coords(mapped_coords)
67
- ax.plot(*cumulative_coords.T)
68
- return fig
69
-
70
- # Function to plot a comparison of DNA sequences
71
- def plot_2d_comparison(dna_sequences_grouped, labels, mapping=mapping_easy):
72
- fig, ax = plt.subplots()
73
- colors = plt.cm.rainbow(np.linspace(0, 1, len(dna_sequences_grouped)))
74
- for count, (dna_sequences, color) in enumerate(zip(dna_sequences_grouped, colors)):
75
- for dna_sequence in dna_sequences:
76
- mapped_coords = _dna_to_coordinates(dna_sequence, mapping)
77
- cumulative_coords = _get_cumulative_coords(mapped_coords)
78
- ax.plot(*cumulative_coords.T, color=color, label=labels[count])
79
- # Only show unique labels in the legend
80
- handles, labels = ax.get_legend_handles_labels()
81
- by_label = dict(zip(labels, handles))
82
- ax.legend(by_label.values(), by_label.keys())
83
- return fig
84
 
85
 
86
  ############################################################# Virus Dataset ########################################################
@@ -94,1140 +24,224 @@ def filter_and_select(group):
94
  if len(group) >= 3:
95
  return group.head(3)
96
 
97
- ############################################################# Wens Method ########################################################
98
- import numpy as np
99
-
100
- WEIGHTS = {'0100': 1/6, '0101': 2/6, '1100' : 3/6, '0110':3/6, '1101': 4/6, '1110': 5/6,'0111':5/6, '1111': 6/6}
101
- LOWEST_LENGTH = 5000
102
-
103
- def _get_subsequences(sequence):
104
- return {nuc: [i+1 for i, x in enumerate(sequence) if x == nuc] for nuc in 'ACTG'}
105
-
106
- def _calculate_coordinates_fixed(subsequence, L=LOWEST_LENGTH):
107
- return [((2 * np.pi / (L - 1)) * (K-1), np.sqrt((2 * np.pi / (L - 1)) * (K-1))) for K in subsequence]
108
-
109
- def _calculate_weighting_full(sequence, WEIGHTS, L=LOWEST_LENGTH, E=0.0375):
110
- weightings = [0]
111
- for i in range(1, len(sequence) - 1):
112
- if i < len(sequence) - 2:
113
- subsequence = sequence[i-1:i+3]
114
- comparison_pattern = f"{'1' if subsequence[0] == subsequence[1] else '0'}1{'1' if subsequence[2] == subsequence[1] else '0'}{'1' if subsequence[3] == subsequence[1] else '0'}"
115
- weight = WEIGHTS.get(comparison_pattern, 0)
116
- weight = weight * E if i > L else weight
117
- else:
118
- weight = 0
119
- weightings.append(weight)
120
- weightings.append(0)
121
- return weightings
122
-
123
- def _centre_of_mass(polar_coordinates, weightings):
124
- x, y = _calculate_standard_coordinates(polar_coordinates)
125
- return sum(weightings[i] * ((x[i] - (x[i]*weightings[i]))**2 + (y[i] - y[i]*weightings[i])**2) for i in range(len(x)))
126
-
127
- def _normalised_moment_of_inertia(polar_coordinates, weightings):
128
- moment = _centre_of_mass(polar_coordinates, weightings)
129
- return np.sqrt(moment / sum(weightings))
130
-
131
- def _calculate_standard_coordinates(polar_coordinates):
132
- return [rho * np.cos(theta) for theta, rho in polar_coordinates], [rho * np.sin(theta) for theta, rho in polar_coordinates]
133
-
134
-
135
- def _moments_of_inertia(polar_coordinates, weightings):
136
- return [_normalised_moment_of_inertia(indices, weightings) for subsequence, indices in polar_coordinates.items()]
137
-
138
- def moment_of_inertia(sequence, WEIGHTS, L=5000, E=0.0375):
139
- subsequences = _get_subsequences(sequence)
140
- polar_coordinates = {subsequence: _calculate_coordinates_fixed(indices, len(sequence)) for subsequence, indices in subsequences.items()}
141
- weightings = _calculate_weighting_full(sequence, WEIGHTS, L=L, E=E)
142
- return _moments_of_inertia(polar_coordinates, weightings)
143
-
144
-
145
- def similarity_wen(sequence1, sequence2, WEIGHTS, L=5000, E=0.0375):
146
- L = min(len(sequence1), len(sequence2))
147
- inertia1 = moment_of_inertia(sequence1, WEIGHTS, L=L, E=E)
148
- inertia2 = moment_of_inertia(sequence2, WEIGHTS, L=L, E=E)
149
- similarity = np.sqrt(sum((x - y)**2 for x, y in zip(inertia1, inertia2)))
150
- return similarity
151
- def heatmap(data, row_labels, col_labels, ax=None,
152
- cbar_kw=None, cbarlabel="", **kwargs):
153
- """
154
- Create a heatmap from a numpy array and two lists of labels.
155
-
156
- Parameters
157
- ----------
158
- data
159
- A 2D numpy array of shape (M, N).
160
- row_labels
161
- A list or array of length M with the labels for the rows.
162
- col_labels
163
- A list or array of length N with the labels for the columns.
164
- ax
165
- A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
166
- not provided, use current axes or create a new one. Optional.
167
- cbar_kw
168
- A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional.
169
- cbarlabel
170
- The label for the colorbar. Optional.
171
- **kwargs
172
- All other arguments are forwarded to `imshow`.
173
- """
174
-
175
- if ax is None:
176
- ax = plt.gca()
177
-
178
- if cbar_kw is None:
179
- cbar_kw = {}
180
-
181
- # Plot the heatmap
182
- im = ax.imshow(data, **kwargs)
183
-
184
- # Create colorbar
185
- cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
186
- cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom")
187
-
188
- # Show all ticks and label them with the respective list entries.
189
- ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
190
- ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
191
-
192
- # Let the horizontal axes labeling appear on top.
193
- ax.tick_params(top=True, bottom=False,
194
- labeltop=True, labelbottom=False)
195
-
196
- # Rotate the tick labels and set their alignment.
197
- plt.setp(ax.get_xticklabels(), rotation=-30, ha="right",
198
- rotation_mode="anchor")
199
-
200
- # Turn spines off and create white grid.
201
- ax.spines[:].set_visible(False)
202
-
203
- ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True)
204
- ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True)
205
- ax.grid(which="minor", color="w", linestyle='-', linewidth=3)
206
- ax.tick_params(which="minor", bottom=False, left=False)
207
-
208
- return im, cbar
209
-
210
-
211
- def annotate_heatmap(im, data=None, valfmt="{x:.2f}",
212
- textcolors=("black", "white"),
213
- threshold=None, **textkw):
214
- """
215
- A function to annotate a heatmap.
216
-
217
- Parameters
218
- ----------
219
- im
220
- The AxesImage to be labeled.
221
- data
222
- Data used to annotate. If None, the image's data is used. Optional.
223
- valfmt
224
- The format of the annotations inside the heatmap. This should either
225
- use the string format method, e.g. "$ {x:.2f}", or be a
226
- `matplotlib.ticker.Formatter`. Optional.
227
- textcolors
228
- A pair of colors. The first is used for values below a threshold,
229
- the second for those above. Optional.
230
- threshold
231
- Value in data units according to which the colors from textcolors are
232
- applied. If None (the default) uses the middle of the colormap as
233
- separation. Optional.
234
- **kwargs
235
- All other arguments are forwarded to each call to `text` used to create
236
- the text labels.
237
- """
238
-
239
- if not isinstance(data, (list, np.ndarray)):
240
- data = im.get_array()
241
-
242
- # Normalize the threshold to the images color range.
243
- if threshold is not None:
244
- threshold = im.norm(threshold)
245
- else:
246
- threshold = im.norm(data.max())/2.
247
-
248
- # Set default alignment to center, but allow it to be
249
- # overwritten by textkw.
250
- kw = dict(horizontalalignment="center",
251
- verticalalignment="center")
252
- kw.update(textkw)
253
-
254
- # Get the formatter in case a string is supplied
255
- if isinstance(valfmt, str):
256
- valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
257
-
258
- # Loop over the data and create a `Text` for each "pixel".
259
- # Change the text's color depending on the data.
260
- texts = []
261
- for i in range(data.shape[0]):
262
- for j in range(data.shape[1]):
263
- kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)])
264
- text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
265
- texts.append(text)
266
-
267
- return texts
268
-
269
- def wens_method_heatmap(df, virus_species):
270
- # Create a dataframe to store the similarity values
271
- similarity_df = pd.DataFrame(index=virus_species, columns=virus_species)
272
- # Fill the dataframe with similarity values
273
- for virus1 in virus_species:
274
- for virus2 in virus_species:
275
- if virus1 == virus2:
276
- sequence1 = df[df['Organism_Name'] == virus1]['Sequence'].values[0]
277
- sequence2 = df[df['Organism_Name'] == virus2]['Sequence'].values[1]
278
- similarity = similarity_wen(sequence1, sequence2, WEIGHTS)
279
- similarity_df.loc[virus1, virus2] = similarity
280
- else:
281
- sequence1 = df[df['Organism_Name'] == virus1]['Sequence'].values[0]
282
- sequence2 = df[df['Organism_Name'] == virus2]['Sequence'].values[0]
283
- similarity = similarity_wen(sequence1, sequence2, WEIGHTS)
284
- similarity_df.loc[virus1, virus2] = similarity
285
- similarity_df = similarity_df.apply(pd.to_numeric)
286
-
287
- # Optional: Handle NaN values if your similarity computation might result in them
288
- # similarity_df.fillna(0, inplace=True)
289
-
290
- fig, ax = plt.subplots()
291
- # Plotting
292
- im = ax.imshow(similarity_df, cmap="YlGn")
293
- ax.set_xticks(np.arange(len(virus_species)), labels=virus_species)
294
- ax.set_yticks(np.arange(len(virus_species)), labels=virus_species)
295
- plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
296
- cbar = ax.figure.colorbar(im, ax=ax)
297
- cbar.ax.set_ylabel("Similarity", rotation=-90, va="bottom")
298
-
299
-
300
- return fig
301
-
302
-
303
- ############################################################# ColorSquare ########################################################
304
- import math
305
- import numpy as np
306
- import matplotlib.pyplot as plt
307
- from matplotlib.colors import ListedColormap
308
- import pandas as pd
309
-
310
- def _fill_spiral(matrix, seq_colors, k):
311
- left, top, right, bottom = 0, 0, k-1, k-1
312
- index = 0
313
- while left <= right and top <= bottom:
314
- for i in range(left, right + 1): # Top row
315
- if index < len(seq_colors):
316
- matrix[top][i] = seq_colors[index]
317
- index += 1
318
- top += 1
319
- for i in range(top, bottom + 1): # Right column
320
- if index < len(seq_colors):
321
- matrix[i][right] = seq_colors[index]
322
- index += 1
323
- right -= 1
324
- for i in range(right, left - 1, -1): # Bottom row
325
- if index < len(seq_colors):
326
- matrix[bottom][i] = seq_colors[index]
327
- index += 1
328
- bottom -= 1
329
- for i in range(bottom, top - 1, -1): # Left column
330
- if index < len(seq_colors):
331
- matrix[i][left] = seq_colors[index]
332
- index += 1
333
- left += 1
334
-
335
-
336
- def _generate_color_square(sequence,virus, save=False, count=0, label=None):
337
- # Define the sequence and corresponding colors with indices
338
- colors = {'a': 0, 't': 1, 'c': 2, 'g': 3, 'n': 4} # Assign indices to each color
339
- seq_colors = [colors[char] for char in sequence.lower()] # Map the sequence to color indices
340
-
341
- # Calculate k (size of the square)
342
- k = math.ceil(math.sqrt(len(sequence)))
343
-
344
- # Initialize a k x k matrix filled with the index for 'white'
345
- matrix = np.full((k, k), colors['n'], dtype=int)
346
-
347
- # Fill the matrix in a clockwise spiral
348
- _fill_spiral(matrix, seq_colors, k)
349
-
350
- # Define a custom color map for plotting
351
- cmap = ListedColormap(['red', 'green', 'yellow', 'blue', 'white'])
352
-
353
- # Plot the matrix
354
- plt.figure(figsize=(5, 5))
355
- plt.imshow(matrix, cmap=cmap, interpolation='nearest')
356
- if label:
357
- plt.title(label)
358
- plt.axis('off') # Hide the axes
359
- if save:
360
- plt.savefig(f'color_square_{virus}_{count}.png', dpi=300, bbox_inches='tight')
361
- # plt.show()
362
-
363
- def plot_color_square(df, virus_species):
364
- ncols = 3
365
- nrows = len(virus_species)
366
- fig, axeses = plt.subplots(
367
- nrows=nrows,
368
- ncols=ncols,
369
- squeeze=False,
370
- )
371
- for i in range(0, ncols * nrows):
372
- row = i // ncols
373
- col = i % ncols
374
- axes = axeses[row, col]
375
- data = df[i]
376
- virus = virus_species[row]
377
- # Define the sequence and corresponding colors with indices
378
- colors = {'a': 0, 't': 1, 'c': 2, 'g': 3, 'n': 4}
379
- # remove all non-nucleotide characters
380
- data = ''.join([char for char in data.lower() if char in 'atcgn'])
381
- # Assign indices to each color
382
- seq_colors = [colors[char] for char in data.lower()] # Map the sequence to color indices
383
-
384
- # Calculate k (size of the square)
385
- k = math.ceil(math.sqrt(len(data)))
386
-
387
- # Initialize a k x k matrix filled with the index for 'white'
388
- matrix = np.full((k, k), colors['n'], dtype=int)
389
-
390
- # Fill the matrix in a clockwise spiral
391
- _fill_spiral(matrix, seq_colors, k)
392
-
393
- # Define a custom color map for plotting
394
- cmap = ListedColormap(['red', 'green', 'yellow', 'blue', 'white'])
395
- axes.imshow(matrix, cmap=cmap, interpolation='nearest')
396
- axes.set_title(virus)
397
- return fig
398
-
399
-
400
-
401
- def generate_color_square(sequence,virus, multi=False, save=False, label=None):
402
- if multi:
403
- for i,seq in enumerate(sequence):
404
- _generate_color_square(seq, virus,save, i, label[i] if label else None)
405
- else:
406
- _generate_color_square(sequence, save, label=label)
407
-
408
-
409
- ############################################################# FCGR ########################################################
410
-
411
- from typing import Dict, Optional
412
- from collections import namedtuple
413
-
414
- # coordinates for x+iy
415
- Coord = namedtuple("Coord", ["x","y"])
416
-
417
- # coordinates for a CGR encoding
418
- CGRCoords = namedtuple("CGRCoords", ["N","x","y"])
419
-
420
- # coordinates for each nucleotide in the 2d-plane
421
- DEFAULT_COORDS = dict(A=Coord(1,1),C=Coord(-1,1),G=Coord(-1,-1),T=Coord(1,-1))
422
-
423
- class CGR:
424
- "Chaos Game Representation for DNA"
425
- def __init__(self, coords: Optional[Dict[chr,tuple]]=None):
426
- self.nucleotide_coords = DEFAULT_COORDS if coords is None else coords
427
- self.cgr_coords = CGRCoords(0,0,0)
428
-
429
- def nucleotide_by_coords(self,x,y):
430
- "Get nucleotide by coordinates (x,y)"
431
- # filter nucleotide by coordinates
432
- filtered = dict(filter(lambda item: item[1] == Coord(x,y), self.nucleotide_coords.items()))
433
-
434
- return list(filtered.keys())[0]
435
-
436
- def forward(self, nucleotide: str):
437
- "Compute next CGR coordinates"
438
- x = (self.cgr_coords.x + self.nucleotide_coords.get(nucleotide).x)/2
439
- y = (self.cgr_coords.y + self.nucleotide_coords.get(nucleotide).y)/2
440
-
441
- # update cgr_coords
442
- self.cgr_coords = CGRCoords(self.cgr_coords.N+1,x,y)
443
-
444
- def backward(self,):
445
- "Compute last CGR coordinates. Current nucleotide can be inferred from (x,y)"
446
- # get current nucleotide based on coordinates
447
- n_x,n_y = self.coords_current_nucleotide()
448
- nucleotide = self.nucleotide_by_coords(n_x,n_y)
449
-
450
- # update coordinates to the previous one
451
- x = 2*self.cgr_coords.x - n_x
452
- y = 2*self.cgr_coords.y - n_y
453
-
454
- # update cgr_coords
455
- self.cgr_coords = CGRCoords(self.cgr_coords.N-1,x,y)
456
-
457
- return nucleotide
458
-
459
- def coords_current_nucleotide(self,):
460
- x = 1 if self.cgr_coords.x>0 else -1
461
- y = 1 if self.cgr_coords.y>0 else -1
462
- return x,y
463
-
464
- def encode(self, sequence: str):
465
- "From DNA sequence to CGR"
466
- # reset starting position to (0,0,0)
467
- self.reset_coords()
468
- for nucleotide in sequence:
469
- self.forward(nucleotide)
470
- return self.cgr_coords
471
-
472
- def reset_coords(self,):
473
- self.cgr_coords = CGRCoords(0,0,0)
474
-
475
- def decode(self, N:int, x:int, y:int)->str:
476
- "From CGR to DNA sequence"
477
- self.cgr_coords = CGRCoords(N,x,y)
478
-
479
- # decoded sequence
480
- sequence = []
481
-
482
- # Recover the entire genome
483
- while self.cgr_coords.N>0:
484
- nucleotide = self.backward()
485
- sequence.append(nucleotide)
486
- return "".join(sequence[::-1])
487
-
488
-
489
- from itertools import product
490
- from collections import defaultdict
491
- import numpy as np
492
-
493
- class FCGR(CGR):
494
- """Frequency matrix CGR
495
- an (2**k x 2**k) 2D representation will be created for a
496
- n-long sequence.
497
- - k represents the k-mer.
498
- - 2**k x 2**k = 4**k the total number of k-mers (sequences of length k)
499
- - pixel value correspond to the value of the frequency for each k-mer
500
- """
501
-
502
- def __init__(self, k: int,):
503
- super().__init__()
504
- self.k = k # k-mer representation
505
- self.kmers = list("".join(kmer) for kmer in product("ACGT", repeat=self.k))
506
- self.kmer2pixel = self.kmer2pixel_position()
507
-
508
- def __call__(self, sequence: str):
509
- "Given a DNA sequence, returns an array with his frequencies in the same order as FCGR"
510
- self.count_kmers(sequence)
511
-
512
- # Create an empty array to save the FCGR values
513
- array_size = int(2**self.k)
514
- freq_matrix = np.zeros((array_size,array_size))
515
-
516
- # Assign frequency to each box in the matrix
517
- for kmer, freq in self.freq_kmer.items():
518
- pos_x, pos_y = self.kmer2pixel[kmer]
519
- freq_matrix[int(pos_x)-1,int(pos_y)-1] = freq
520
- return freq_matrix
521
-
522
- def count_kmer(self, kmer):
523
- if "N" not in kmer:
524
- self.freq_kmer[kmer] += 1
525
-
526
- def count_kmers(self, sequence: str):
527
- self.freq_kmer = defaultdict(int)
528
- # representativity of kmers
529
- last_j = len(sequence) - self.k + 1
530
- kmers = (sequence[i:(i+self.k)] for i in range(last_j))
531
- # count kmers in a dictionary
532
- list(self.count_kmer(kmer) for kmer in kmers)
533
-
534
- def kmer_probabilities(self, sequence: str):
535
- self.probabilities = defaultdict(float)
536
- N=len(sequence)
537
- for key, value in self.freq_kmer.items():
538
- self.probabilities[key] = float(value) / (N - self.k + 1)
539
-
540
- def pixel_position(self, kmer: str):
541
- "Get pixel position in the FCGR matrix for a k-mer"
542
-
543
- coords = self.encode(kmer)
544
- N,x,y = coords.N, coords.x, coords.y
545
-
546
- # Coordinates from [-1,1]² to [1,2**k]²
547
- np_coords = np.array([(x + 1)/2, (y + 1)/2]) # move coordinates from [-1,1]² to [0,1]²
548
- np_coords *= 2**self.k # rescale coordinates from [0,1]² to [0,2**k]²
549
- x,y = np.ceil(np_coords) # round to upper integer
550
-
551
- # Turn coordinates (cx,cy) into pixel (px,py) position
552
- # px = 2**k-cy+1, py = cx
553
- return 2**self.k-int(y)+1, int(x)
554
-
555
- def kmer2pixel_position(self,):
556
- kmer2pixel = dict()
557
- for kmer in self.kmers:
558
- kmer2pixel[kmer] = self.pixel_position(kmer)
559
- return kmer2pixel
560
-
561
-
562
- from tqdm import tqdm
563
- from pathlib import Path
564
-
565
- import numpy as np
566
-
567
-
568
- class GenerateFCGR:
569
- def __init__(self, kmer: int = 5, ):
570
- self.kmer = kmer
571
- self.fcgr = FCGR(kmer)
572
- self.counter = 0 # count number of time a sequence is converted to fcgr
573
-
574
-
575
- def __call__(self, list_fasta,):
576
-
577
- for fasta in tqdm(list_fasta, desc="Generating FCGR"):
578
- self.from_fasta(fasta)
579
-
580
-
581
-
582
-
583
- def from_seq(self, seq: str):
584
- "Get FCGR from a sequence"
585
- seq = self.preprocessing(seq)
586
- chaos = self.fcgr(seq)
587
- self.counter +=1
588
- return chaos
589
-
590
- def reset_counter(self,):
591
- self.counter=0
592
-
593
- @staticmethod
594
- def preprocessing(seq):
595
- seq = seq.upper()
596
- for letter in seq:
597
- if letter not in "ATCG":
598
- seq = seq.replace(letter,"N")
599
- return seq
600
-
601
- def plot_fcgr(df, virus_species):
602
- ncols = 3
603
- nrows = len(virus_species)
604
- fig, axeses = plt.subplots(
605
- nrows=nrows,
606
- ncols=ncols,
607
- squeeze=False,
608
- )
609
- for i in range(0, ncols * nrows):
610
- row = i // ncols
611
- col = i % ncols
612
- axes = axeses[row, col]
613
- data = df[i].upper()
614
- chaos = GenerateFCGR().from_seq(seq=data)
615
- virus = virus_species[row]
616
- axes.imshow(chaos)
617
- axes.set_title(virus)
618
- return fig
619
-
620
- ############################################################# Persistant Homology ########################################################
621
- import numpy as np
622
- import persim
623
- import ripser
624
- import matplotlib.pyplot as plt
625
-
626
- NUCLEOTIDE_MAPPING = {
627
- 'a': np.array([1, 0, 0, 0]),
628
- 'c': np.array([0, 1, 0, 0]),
629
- 'g': np.array([0, 0, 1, 0]),
630
- 't': np.array([0, 0, 0, 1])
631
- }
632
-
633
- def encode_nucleotide_to_vector(nucleotide):
634
- return NUCLEOTIDE_MAPPING.get(nucleotide)
635
-
636
- def chaos_4d_representation(dna_sequence):
637
- points = [encode_nucleotide_to_vector(dna_sequence[0])]
638
- for nucleotide in dna_sequence[1:]:
639
- vector = encode_nucleotide_to_vector(nucleotide)
640
- if vector is None:
641
- continue
642
- next_point = 0.5 * (points[-1] + vector)
643
- points.append(next_point)
644
- return np.array(points)
645
-
646
- def persistence_homology(dna_sequence, multi=False, plot=False, sample_rate=7):
647
- if multi:
648
- c4dr_points = np.array([chaos_4d_representation(sequence) for sequence in dna_sequence])
649
- dgm_dna = [ripser.ripser(points[::sample_rate], maxdim=1)['dgms'] for points in c4dr_points]
650
- if plot:
651
- persim.plot_diagrams([dgm[1] for dgm in dgm_dna], labels=[f'sequence {i}' for i in range(len(dna_sequence))])
652
- else:
653
- c4dr_points = chaos_4d_representation(dna_sequence)
654
- dgm_dna = ripser.ripser(c4dr_points[::sample_rate], maxdim=1)['dgms']
655
- if plot:
656
- persim.plot_diagrams(dgm_dna[1])
657
- return dgm_dna
658
-
659
- def plot_diagrams(
660
- diagrams,
661
- plot_only=None,
662
- title=None,
663
- xy_range=None,
664
- labels=None,
665
- colormap="default",
666
- size=20,
667
- ax_color=np.array([0.0, 0.0, 0.0]),
668
- diagonal=True,
669
- lifetime=False,
670
- legend=True,
671
- show=False,
672
- ax=None
673
- ):
674
- """A helper function to plot persistence diagrams.
675
-
676
- Parameters
677
- ----------
678
- diagrams: ndarray (n_pairs, 2) or list of diagrams
679
- A diagram or list of diagrams. If diagram is a list of diagrams,
680
- then plot all on the same plot using different colors.
681
- plot_only: list of numeric
682
- If specified, an array of only the diagrams that should be plotted.
683
- title: string, default is None
684
- If title is defined, add it as title of the plot.
685
- xy_range: list of numeric [xmin, xmax, ymin, ymax]
686
- User provided range of axes. This is useful for comparing
687
- multiple persistence diagrams.
688
- labels: string or list of strings
689
- Legend labels for each diagram.
690
- If none are specified, we use H_0, H_1, H_2,... by default.
691
- colormap: string, default is 'default'
692
- Any of matplotlib color palettes.
693
- Some options are 'default', 'seaborn', 'sequential'.
694
- See all available styles with
695
-
696
- .. code:: python
697
-
698
- import matplotlib as mpl
699
- print(mpl.styles.available)
700
-
701
- size: numeric, default is 20
702
- Pixel size of each point plotted.
703
- ax_color: any valid matplotlib color type.
704
- See [https://matplotlib.org/api/colors_api.html](https://matplotlib.org/api/colors_api.html) for complete API.
705
- diagonal: bool, default is True
706
- Plot the diagonal x=y line.
707
- lifetime: bool, default is False. If True, diagonal is turned to False.
708
- Plot life time of each point instead of birth and death.
709
- Essentially, visualize (x, y-x).
710
- legend: bool, default is True
711
- If true, show the legend.
712
- show: bool, default is False
713
- Call plt.show() after plotting. If you are using self.plot() as part
714
- of a subplot, set show=False and call plt.show() only once at the end.
715
- """
716
-
717
- fig, ax = plt.subplots() if ax is None else ax
718
- plt.style.use(colormap)
719
-
720
- xlabel, ylabel = "Birth", "Death"
721
-
722
- if not isinstance(diagrams, list):
723
- # Must have diagrams as a list for processing downstream
724
- diagrams = [diagrams]
725
-
726
- if labels is None:
727
- # Provide default labels for diagrams if using self.dgm_
728
- labels = ["$H_{{{}}}$".format(i) for i , _ in enumerate(diagrams)]
729
-
730
- if plot_only:
731
- diagrams = [diagrams[i] for i in plot_only]
732
- labels = [labels[i] for i in plot_only]
733
-
734
- if not isinstance(labels, list):
735
- labels = [labels] * len(diagrams)
736
-
737
- # Construct copy with proper type of each diagram
738
- # so we can freely edit them.
739
- diagrams = [dgm.astype(np.float32, copy=True) for dgm in diagrams]
740
-
741
- # find min and max of all visible diagrams
742
- concat_dgms = np.concatenate(diagrams).flatten()
743
- has_inf = np.any(np.isinf(concat_dgms))
744
- finite_dgms = concat_dgms[np.isfinite(concat_dgms)]
745
-
746
- # clever bounding boxes of the diagram
747
- if not xy_range:
748
- # define bounds of diagram
749
- ax_min, ax_max = np.min(finite_dgms), np.max(finite_dgms)
750
- x_r = ax_max - ax_min
751
-
752
- # Give plot a nice buffer on all sides.
753
- # ax_range=0 when only one point,
754
- buffer = 1 if xy_range == 0 else x_r / 5
755
-
756
- x_down = ax_min - buffer / 2
757
- x_up = ax_max + buffer
758
-
759
- y_down, y_up = x_down, x_up
760
- else:
761
- x_down, x_up, y_down, y_up = xy_range
762
-
763
- yr = y_up - y_down
764
-
765
- if lifetime:
766
-
767
- # Don't plot landscape and diagonal at the same time.
768
- diagonal = False
769
-
770
- # reset y axis so it doesn't go much below zero
771
- y_down = -yr * 0.05
772
- y_up = y_down + yr
773
-
774
- # set custom ylabel
775
- ylabel = "Lifetime"
776
-
777
- # set diagrams to be (x, y-x)
778
- for dgm in diagrams:
779
- dgm[:, 1] -= dgm[:, 0]
780
-
781
- # plot horizon line
782
- ax.plot([x_down, x_up], [0, 0], c=ax_color)
783
-
784
- # Plot diagonal
785
- if diagonal:
786
- ax.plot([x_down, x_up], [x_down, x_up], "--", c=ax_color)
787
-
788
- # Plot inf line
789
- if has_inf:
790
- # put inf line slightly below top
791
- b_inf = y_down + yr * 0.95
792
- ax.plot([x_down, x_up], [b_inf, b_inf], "--", c="k", label=r"$\infty$")
793
-
794
- # convert each inf in each diagram with b_inf
795
- for dgm in diagrams:
796
- dgm[np.isinf(dgm)] = b_inf
797
-
798
- # Plot each diagram
799
- for dgm, label in zip(diagrams, labels):
800
-
801
- # plot persistence pairs
802
- ax.scatter(dgm[:, 0], dgm[:, 1], size, label=label, edgecolor="none")
803
-
804
- ax.set_xlabel(xlabel)
805
- ax.set_ylabel(ylabel)
806
-
807
- ax.set_xlim([x_down, x_up])
808
- ax.set_ylim([y_down, y_up])
809
- ax.set_aspect('equal', 'box')
810
-
811
- if title is not None:
812
- ax.set_title(title)
813
-
814
- if legend is True:
815
- ax.legend(loc="lower right")
816
-
817
- if show is True:
818
- plt.show()
819
- return fig, ax
820
-
821
-
822
- def plot_persistence_homology(df, virus_species):
823
- # if len(virus_species.unique()) > 1:
824
- c4dr_points = [chaos_4d_representation(sequence.lower()) for sequence in df]
825
- dgm_dna = [ripser.ripser(points[::15], maxdim=1)['dgms'] for points in c4dr_points]
826
- labels =[f'{virus_specie}_{i}' for i, virus_specie in enumerate(virus_species)]
827
- fig, ax = plot_diagrams([dgm[1] for dgm in dgm_dna], labels=labels)
828
- # else:
829
- # c4dr_points = [chaos_4d_representation(sequence.lower()) for sequence in df]
830
- # dgm_dna = [ripser.ripser(points[::10], maxdim=1)['dgms'] for points in c4dr_points]
831
- # labels =[f'{virus_specie}_{i}' for i, virus_specie in enumerate(virus_species)]
832
- # print(labels)
833
- # print(len(dgm_dna))
834
- # fig, ax = plot_diagrams([dgm[1] for dgm in dgm_dna], labels=labels)
835
- return fig
836
-
837
- def compare_persistence_homology(dna_sequence1, dna_sequence2):
838
- dgm_dna1 = persistence_homology(dna_sequence1)
839
- dgm_dna2 = persistence_homology(dna_sequence2)
840
- distance = persim.sliced_wasserstein(dgm_dna1[1], dgm_dna2[1])
841
- return distance
842
-
843
  ############################################################# UI #################################################################
844
 
845
  ui.page_opts(fillable=True)
846
 
847
- with ui.navset_card_tab(id="tab"):
848
  with ui.nav_panel("Viral Macrostructure"):
849
- ui.page_opts(fillable=True)
850
  ui.panel_title("Do viruses have underlying structure?")
851
  with ui.layout_columns():
852
  with ui.card():
853
- ui.input_selectize(
854
- "virus_selector",
855
- "Select your viruses:",
856
- virus,
857
- multiple=True, selected=None
858
- )
859
  with ui.card():
860
- ui.input_selectize(
861
- "plot_type_macro",
862
- "Select your method:",
863
- ["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"],
864
- multiple=False, selected=None
865
- )
866
-
867
- ############################################################# Plotting ########################################################
868
- here = Path(__file__).parent
869
- import matplotlib as mpl
870
- # @output(suspend_when_hidden=True)
871
- @render.plot()
872
- def plot_macro():
873
- #ds = load_dataset('Hack90/virus_tiny')
874
- df = pd.read_parquet('virus_ds.parquet')
875
- df = df[df['Organism_Name'].isin(input.virus_selector())]
876
- # group by virus
877
- grouped = df.groupby('Organism_Name')['Sequence'].apply(list)
878
- mpl.rcParams.update(mpl.rcParamsDefault)
879
-
880
- # plot the comparison
881
- fig = None
882
- if input.plot_type_macro() == "2D Line":
883
- fig = plot_2d_comparison(grouped, grouped.index)
884
- if input.plot_type_macro() == "ColorSquare":
885
- filtered_df = df.groupby('Organism_Name').apply(filter_and_select).reset_index(drop=True)
886
- fig = plot_color_square(filtered_df['Sequence'], filtered_df['Organism_Name'].unique())
887
- if input.plot_type_macro() == "Wens Method":
888
- fig = wens_method_heatmap(df, df['Organism_Name'].unique())
889
- if input.plot_type_macro() == "Chaos Game Representation":
890
- filtered_df = df.groupby('Organism_Name').apply(filter_and_select).reset_index(drop=True)
891
- fig = plot_fcgr(filtered_df['Sequence'], df['Organism_Name'].unique())
892
- if input.plot_type_macro() == "Persistant Homology":
893
- filtered_df = df.groupby('Organism_Name').apply(filter_and_select).reset_index(drop=True)
894
- fig = plot_persistence_homology(filtered_df['Sequence'], filtered_df['Organism_Name'])
895
- return fig
896
- # ui.output_plot("plot_macro_output")
897
- # with ui.nav_panel("Viral Model"):
898
- # gr.load("models/Hack90/virus_pythia_31_1024").launch()
899
 
900
- with ui.nav_panel("Viral Microstructure"):
901
- ui.page_opts(fillable=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
902
  ui.panel_title("Kmer Distribution")
903
  with ui.layout_columns():
904
  with ui.card():
905
  ui.input_slider("kmer", "kmer", 0, 10, 4)
906
  ui.input_slider("top_k", "top:", 0, 1000, 15)
907
-
908
- ui.input_selectize(
909
- "plot_type",
910
- "Select metric:",
911
- ["percentage", "count"],
912
- multiple=False, selected=None
913
- )
914
-
915
- import matplotlib as mpl
916
- # @output(suspend_when_hidden=True)
917
  @render.plot()
918
- def plot_micro():
919
- df = pd.read_csv('kmers.csv')
920
  k = input.kmer()
921
  top_k = input.top_k()
922
- fig = None
923
- mpl.rcParams.update(mpl.rcParamsDefault)
924
- if input.plot_type() == "count" and input.kmer() > 0:
925
- df = df[df['k'] == k]
926
- df = df.head(top_k)
927
- fig, ax = plt.subplots()
928
- ax.bar(df['kmer'], df['count'])
929
- ax.set_title(f"Most common {k}-mers")
930
- ax.set_xlabel("K-mer")
931
- ax.set_ylabel("Count")
932
- ax.set_xticklabels(df['kmer'], rotation=90)
933
- if input.plot_type() == "percentage" and input.kmer() > 0:
934
- df = df[df['k'] == k]
935
- df = df.head(top_k)
936
  fig, ax = plt.subplots()
937
- ax.bar(df['kmer'], df['percent']*100)
 
 
 
 
 
938
  ax.set_title(f"Most common {k}-mers")
939
  ax.set_xlabel("K-mer")
940
- ax.set_ylabel("Percentage")
941
- ax.set_xticklabels(df['kmer'], rotation=90)
942
- return fig
943
- #ui.output_plot("plot_micro_output")
944
- with ui.nav_panel("Viral Model Training"):
945
- ui.page_opts(fillable=True)
946
  ui.panel_title("Does context size matter for a nucleotide model?")
947
-
948
- def plot_loss_rates(df, type):
949
- # interplot each column to be same number of points
950
  x = np.linspace(0, 1, 1000)
951
  loss_rates = []
952
- labels = ['32', '64', '128', '256', '512', '1024']
953
- #drop the column step
954
- df = df.drop(columns=['Step'])
955
  for col in df.columns:
956
- y = df[col].dropna().astype('float', errors = 'ignore').dropna().values
957
  f = interp1d(np.linspace(0, 1, len(y)), y)
958
  loss_rates.append(f(x))
959
  fig, ax = plt.subplots()
960
  for i, loss_rate in enumerate(loss_rates):
961
  ax.plot(x, loss_rate, label=labels[i])
962
  ax.legend()
963
- ax.set_title(f'Loss rates for a {type} parameter model across context windows')
964
- ax.set_xlabel('Training steps')
965
- ax.set_ylabel('Loss rate')
966
  return fig
967
-
968
- import matplotlib as mpl
969
  @render.image
970
  def plot_context_size_scaling():
971
- fig = None
972
- df = pd.read_csv('14m.csv')
973
- mpl.rcParams.update(mpl.rcParamsDefault)
974
- fig = plot_loss_rates(df, '14M')
975
- import tempfile
976
- fd, path = tempfile.mkstemp(suffix = '.svg')
977
  if fig:
 
 
 
978
  fig.savefig(path)
979
- return {"src": str(path), "width": "600px", "format":"svg"}
980
- return fig
981
  with ui.nav_panel("Model loss analysis"):
982
- ui.page_opts(fillable=True)
983
  ui.panel_title("Neurips stuff")
984
-
985
  with ui.card():
986
  ui.input_selectize(
987
- "param_type",
988
- "Select Param Type:",
989
- ["14", "31", "70", "160", "410"],
990
- multiple=True,
991
- selected=["14", "70"]
992
- )
993
  ui.input_selectize(
994
- "model_type",
995
- "Select Model Type:",
996
- ["pythia", "denseformer", "evo"],
997
- multiple=True,
998
- selected=['pythia','denseformer']
999
- )
1000
  ui.input_selectize(
1001
- "loss_type",
1002
- "Select Loss Type:",
1003
- ["compliment", "cross_entropy", "headless", "2d", "2d_representation_MSEPlusCE"],
1004
- multiple=True,
1005
- selected=["compliment", "cross_entropy", "headless"]
1006
- )
1007
- #ui.input_slider("x_filter", "x_filter", 0, 1, 0.01)
1008
  def plot_loss_rates_model(df, param_types, loss_types, model_types):
1009
- # interplot each column to be same number of points
1010
  x = np.linspace(0, 1, 1000)
1011
  loss_rates = []
1012
  labels = []
1013
- print(param_types, loss_types, model_types)
1014
  for param_type in param_types:
1015
  for loss_type in loss_types:
1016
  for model_type in model_types:
1017
- y = df[(df['param_type'] == int(param_type)) & (df['loss_type'] == loss_type) & (df['model_type'] == model_type)]['loss_interp'].values
1018
- print(y)
1019
-
 
 
1020
  if len(y) > 0:
1021
  f = interp1d(np.linspace(0, 1, len(y)), y)
1022
  loss_rates.append(f(x))
1023
- labels.append(str(param_type) + '_' + loss_type + '_' + model_type)
1024
-
1025
  fig, ax = plt.subplots()
1026
- # print(loss_rates)
1027
-
1028
  for i, loss_rate in enumerate(loss_rates):
1029
- # df_madmad = pd.DataFrame({'x':x, 'loss':loss_rate})
1030
-
1031
- # # df_madmad = df_madmad.sort_values(by='x')
1032
- # df_madmad = df_madmad[df_madmad['x']>x_filter]
1033
- # x = df_madmad['x'].to_list()
1034
- # loss_rate = df_madmad['loss'].to_list(
1035
  ax.plot(x, loss_rate, label=labels[i])
1036
-
1037
-
1038
  ax.legend()
1039
- ax.set_xlabel('Training steps')
1040
- ax.set_ylabel('Loss rate')
1041
-
1042
  return fig
1043
-
1044
- import matplotlib as mpl
1045
  @render.image
1046
  def plot_model_scaling():
1047
- fig = None
1048
- df = pd.read_csv('training_data_5.csv')
1049
- df = df[df['epoch_interp']>0.035]
1050
- mpl.rcParams.update(mpl.rcParamsDefault)
1051
- fig = plot_loss_rates_model(df, input.param_type(),input.loss_type(),input.model_type() )
1052
-
1053
- import tempfile
1054
- fd, path = tempfile.mkstemp(suffix = '.svg')
1055
  if fig:
 
 
 
1056
  fig.savefig(path)
1057
- return {"src": str(path), "width": "600px", "format":"svg"}
1058
- return fig
1059
  with ui.nav_panel("Scaling Laws"):
1060
- ui.page_opts(fillable=True)
1061
  ui.panel_title("Params & Losses")
1062
-
1063
  with ui.card():
1064
-
1065
  ui.input_selectize(
1066
- "model_type_scale",
1067
- "Select Model Type:",
1068
- ["pythia", "denseformer", "evo"],
1069
- multiple=True,
1070
- selected=['evo','denseformer']
1071
- )
1072
  ui.input_selectize(
1073
- "loss_type_scale",
1074
- "Select Loss Type:",
1075
- ["compliment", "cross_entropy", "headless", "2d", "2d_representation_MSEPlusCE"],
1076
- multiple=True,
1077
- selected=["cross_entropy"]
1078
- )
 
1079
  def plot_loss_rates_model_scale(df, loss_type, model_types):
1080
- df = df[df['loss_type'] == loss_type[0]]
1081
- # interplot each column to be same number of points
1082
  params = []
1083
  loss_rates = []
1084
  labels = []
1085
  for model_type in model_types:
1086
- df_new = df[df['model_type']==model_type]
1087
  losses = []
1088
  params_model = []
1089
- # print(df_new)
1090
- for paramy in df_new['num_params'].unique():
1091
- loss = df_new[df_new['num_params']==paramy]['loss_interp'].min()
1092
- print(loss)
1093
- par = int(paramy)
1094
- print(par)
1095
- losses.append(loss)
1096
- params_model.append(par)
1097
- df_reorder = pd.DataFrame({'loss':losses, 'params':params_model})
1098
- df_reorder = df_reorder.sort_values(by='params')
1099
- print(df_reorder)
1100
- loss_rates.append(df_reorder['loss'].to_list())
1101
- params.append(df_reorder['params'].to_list())
1102
  labels.append(model_type)
1103
-
1104
  fig, ax = plt.subplots()
1105
-
1106
  for i, loss_rate in enumerate(loss_rates):
1107
  ax.plot(params[i], loss_rate, label=labels[i])
1108
-
1109
  ax.legend()
1110
- ax.set_xlabel('Params')
1111
- ax.set_ylabel('Loss')
1112
-
1113
  return fig
1114
-
1115
-
1116
- # import matplotlib as mpl
1117
  @render.image
1118
  def plot_big_boy_model():
1119
- fig = None
1120
- df = pd.read_csv('training_data_5.csv')
1121
- mpl.rcParams.update(mpl.rcParamsDefault)
1122
- fig = plot_loss_rates_model_scale(df,input.loss_type_scale(),input.model_type_scale())
1123
- import tempfile
1124
- fd, path = tempfile.mkstemp(suffix = '.svg')
1125
  if fig:
1126
- fig.savefig(path)
1127
- return {"src": str(path), "width": "600px", "format":"svg"}
1128
- return fig
1129
- # @output
1130
- # @render.plot
1131
- # def plot_training_loss():
1132
- # # if csv_file() is None:
1133
- # # return None
1134
-
1135
- # df = pd.read_csv('results - denseformer.csv')
1136
-
1137
- # filtered_df = df[
1138
- # (df["param_type"].isin(input.param_type()))
1139
- # & (df["model_type"].isin(input.model_type()))
1140
- # & (df["loss_type"].isin(input.loss_type()))
1141
- # ]
1142
-
1143
 
1144
- # if filtered_df.empty:
1145
- # return None
1146
-
1147
- # # Define colors for sizes and shapes for loss types
1148
- # size_colors = {
1149
- # "14": "blue",
1150
- # "31": "green",
1151
- # "70": "orange",
1152
- # "160": "red"
1153
- # }
1154
-
1155
- # loss_markers = {
1156
- # "compliment": "o",
1157
- # "cross_entropy": "^",
1158
- # "headless": "s"
1159
- # }
1160
-
1161
- # # Create the plot
1162
- # fig, ax = plt.subplots(figsize=(10, 6))
1163
-
1164
- # # Plot each combination of size and loss type
1165
- # for size in filtered_df["param_type"].unique():
1166
- # for loss_type in filtered_df["loss_type"].unique():
1167
- # data = filtered_df[(filtered_df["param_type"] == size) & (filtered_df["loss_type"] == loss_type)]
1168
- # ax.plot(data["epoch"], data["loss"], marker=loss_markers[loss_type], color=size_colors[size], label=f"{size} - {loss_type}")
1169
-
1170
- # # Customize the plot
1171
- # ax.set_xlabel("Epoch")
1172
- # ax.set_ylabel("Loss")
1173
- # # ax.set_title("Training Loss by Size and Loss Type", fontsize=16)
1174
-
1175
- # # Create a legend for sizes
1176
- # size_legend = ax.legend(title="Size", loc="upper right")
1177
- # ax.add_artist(size_legend)
1178
-
1179
- # # Create a separate legend for loss types
1180
- # loss_legend_labels = ["Compliment", "Cross Entropy", "Headless"]
1181
- # loss_legend_handles = [plt.Line2D([0], [0], marker=loss_markers[loss_type], color='black', linestyle='None', markersize=8) for loss_type in loss_markers]
1182
- # loss_legend = ax.legend(loss_legend_handles, loss_legend_labels, title="Loss Type", loc="upper right")
1183
-
1184
- # plt.tight_layout()
1185
- # return fig
1186
-
1187
- # # Define colors for sizes and shapes for loss types
1188
- # size_colors = {
1189
- # "14": "blue",
1190
- # "31": "green",
1191
- # "70": "orange",
1192
- # "160": "red"
1193
- # }
1194
- # loss_markers = {
1195
- # "compliment": "o",
1196
- # "cross_entropy": "^",
1197
- # "headless": "s"
1198
- # }
1199
-
1200
- # # Create a relplot using Seaborn
1201
- # g = sns.relplot(
1202
- # data=filtered_df,
1203
- # x="epoch",
1204
- # y="loss",
1205
- # hue="param_type",
1206
- # style="loss_type",
1207
- # palette=size_colors,
1208
- # markers=loss_markers,
1209
- # height=6,
1210
- # aspect=1.5
1211
- # )
1212
-
1213
- # # Customize the plot
1214
- # g.set_xlabels("Epoch")
1215
- # g.set_ylabels("Loss")
1216
- # g.fig.suptitle("Training Loss by Size and Loss Type", fontsize=16)
1217
- # g.add_legend(title="Size")
1218
-
1219
- # # Create a separate legend for loss types
1220
- # loss_legend = plt.legend(title="Loss Type", loc="upper right", labels=["Compliment", "Cross Entropy", "Headless"])
1221
- # plt.gca().add_artist(loss_legend)
1222
-
1223
- # plt.tight_layout()
1224
- # return g.fig
1225
-
1226
-
1227
- # @render.image
1228
- # def image():
1229
- # img = None
1230
- # if input.plot_type() == "ColorSquare":
1231
- # img = {"src": f"color_square_{input.virus_selector()[0]}_0.png", "alt": "ColorSquare"}
1232
- # return img
1233
- # return img
 
 
 
 
1
  import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
  from scipy.interpolate import interp1d
  from shiny.express import input, render, ui

  from utils import (
      filter_and_select,
      plot_2d_comparison,
      plot_color_square,
      plot_fcgr,
      plot_persistence_homology,
      wens_method_heatmap,
  )
 
 
 
 
 
 
 
 
 
12
 
 
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
 
16
  ############################################################# Virus Dataset ########################################################
 
24
  if len(group) >= 3:
25
  return group.head(3)
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  ############################################################# UI #################################################################
28
 
29
  ui.page_opts(fillable=True)
30
 
31
+ with ui.navset_card_tab(id="tab"):
32
  with ui.nav_panel("Viral Macrostructure"):
 
33
  ui.panel_title("Do viruses have underlying structure?")
34
  with ui.layout_columns():
35
  with ui.card():
36
+ ui.input_selectize("virus_selector", "Select your viruses:", virus, multiple=True, selected=None)
 
 
 
 
 
37
  with ui.card():
38
+ ui.input_selectize(
39
+ "plot_type_macro",
40
+ "Select your method:",
41
+ ["Chaos Game Representation", "2D Line", "ColorSquare", "Persistant Homology", "Wens Method"],
42
+ multiple=False,
43
+ selected=None,
44
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
@render.plot()
def plot_macro():
    """Render the macrostructure plot for the viruses selected in the UI.

    Loads ``virus_ds.parquet``, keeps the rows whose ``Organism_Name`` is in
    the ``virus_selector`` input and passes them to the plotting helper that
    matches the ``plot_type_macro`` input.

    Returns:
        Whatever the chosen plotting helper returns, or ``None`` when the
        selected method is not one of the names handled below.
    """
    records = pd.read_parquet("virus_ds.parquet")
    records = records[records["Organism_Name"].isin(input.virus_selector())]
    # One list of sequences per organism; only the "2D Line" view consumes it.
    grouped = records.groupby("Organism_Name")["Sequence"].apply(list)

    def sample_per_organism():
        # Apply filter_and_select to each organism's rows, then flatten the
        # result back into a single frame with a fresh index.
        return records.groupby("Organism_Name").apply(filter_and_select).reset_index(drop=True)

    method = input.plot_type_macro()

    if method == "2D Line":
        return plot_2d_comparison(grouped, grouped.index)

    if method == "ColorSquare":
        sample = sample_per_organism()
        return plot_color_square(sample["Sequence"], sample["Organism_Name"].unique())

    if method == "Wens Method":
        return wens_method_heatmap(records, records["Organism_Name"].unique())

    if method == "Chaos Game Representation":
        sample = sample_per_organism()
        # NOTE(review): the labels are taken from the unsampled frame here,
        # whereas the ColorSquare branch takes them from the sampled frame.
        # Confirm this difference is intended.
        return plot_fcgr(sample["Sequence"], records["Organism_Name"].unique())

    if method == "Persistant Homology":
        sample = sample_per_organism()
        return plot_persistence_homology(sample["Sequence"], sample["Organism_Name"])

    return None
66
+
67
+ with ui.nav_panel("Viral Microstructure"):
68
  ui.panel_title("Kmer Distribution")
69
  with ui.layout_columns():
70
  with ui.card():
71
  ui.input_slider("kmer", "kmer", 0, 10, 4)
72
  ui.input_slider("top_k", "top:", 0, 1000, 15)
73
+ ui.input_selectize("plot_type", "Select metric:", ["percentage", "count"], multiple=False, selected=None)
74
+
 
 
 
 
 
 
 
 
75
  @render.plot()
76
+ def plot_micro():
77
+ df = pd.read_csv("kmers.csv")
78
  k = input.kmer()
79
  top_k = input.top_k()
80
+ plot_type = input.plot_type()
81
+
82
+ if k > 0:
83
+ df = df[df["k"] == k].head(top_k)
 
 
 
 
 
 
 
 
 
 
84
  fig, ax = plt.subplots()
85
+ if plot_type == "count":
86
+ ax.bar(df["kmer"], df["count"])
87
+ ax.set_ylabel("Count")
88
+ elif plot_type == "percentage":
89
+ ax.bar(df["kmer"], df["percent"] * 100)
90
+ ax.set_ylabel("Percentage")
91
  ax.set_title(f"Most common {k}-mers")
92
  ax.set_xlabel("K-mer")
93
+ ax.set_xticklabels(df["kmer"], rotation=90)
94
+ return fig
95
+
96
+ with ui.nav_panel("Viral Model Training"):
 
 
97
  ui.panel_title("Does context size matter for a nucleotide model?")
98
+
99
def plot_loss_rates(df, model_type):
    """Plot every loss curve in ``df`` on a shared, normalised x axis.

    Each column other than ``Step`` is treated as one loss series. Because
    the series can have different lengths, each one is linearly interpolated
    onto 1000 evenly spaced points in [0, 1] before plotting.

    Args:
        df: DataFrame with one loss series per column and, optionally, a
            ``Step`` column, which is discarded.
        model_type: Text inserted into the plot title.

    Returns:
        The matplotlib figure containing one line per usable column.
    """
    x = np.linspace(0, 1, 1000)
    context_labels = ["32", "64", "128", "256", "512", "1024"]
    # errors="ignore": a frame without a Step column is still plottable.
    series_frame = df.drop(columns=["Step"], errors="ignore")

    curves = []
    for position, col in enumerate(series_frame.columns):
        # Coerce rather than astype(errors="ignore"): unparseable cells become
        # NaN and are dropped, instead of leaving an object array that
        # interp1d cannot work with.
        y = pd.to_numeric(series_frame[col], errors="coerce").dropna().to_numpy()
        if len(y) < 2:
            # Interpolating over [0, 1] needs at least two points.
            continue
        f = interp1d(np.linspace(0, 1, len(y)), y)
        # Labels are positional; fall back to the column name when there are
        # more columns than predefined context sizes.
        if position < len(context_labels):
            label = context_labels[position]
        else:
            label = str(col)
        curves.append((label, f(x)))

    fig, ax = plt.subplots()
    for label, loss_rate in curves:
        ax.plot(x, loss_rate, label=label)
    ax.legend()
    ax.set_title(f"Loss rates for a {model_type} parameter model across context windows")
    ax.set_xlabel("Training steps")
    ax.set_ylabel("Loss rate")
    return fig
116
+
 
117
  @render.image
118
  def plot_context_size_scaling():
119
+ df = pd.read_csv("14m.csv")
120
+ fig = plot_loss_rates(df, "14M")
 
 
 
 
121
  if fig:
122
+ import tempfile
123
+
124
+ fd, path = tempfile.mkstemp(suffix=".svg")
125
  fig.savefig(path)
126
+ return {"src": str(path), "width": "600px", "format": "svg"}
127
+
128
  with ui.nav_panel("Model loss analysis"):
 
129
  ui.panel_title("Neurips stuff")
 
130
  with ui.card():
131
  ui.input_selectize(
132
+ "param_type",
133
+ "Select Param Type:",
134
+ ["14", "31", "70", "160", "410"],
135
+ multiple=True,
136
+ selected=["14", "70"],
137
+ )
138
  ui.input_selectize(
139
+ "model_type",
140
+ "Select Model Type:",
141
+ ["pythia", "denseformer", "evo"],
142
+ multiple=True,
143
+ selected=["pythia", "denseformer"],
144
+ )
145
  ui.input_selectize(
146
+ "loss_type",
147
+ "Select Loss Type:",
148
+ ["compliment", "cross_entropy", "headless", "2d", "2d_representation_MSEPlusCE"],
149
+ multiple=True,
150
+ selected=["compliment", "cross_entropy", "headless"],
151
+ )
152
+
153
  def plot_loss_rates_model(df, param_types, loss_types, model_types):
 
154
  x = np.linspace(0, 1, 1000)
155
  loss_rates = []
156
  labels = []
 
157
  for param_type in param_types:
158
  for loss_type in loss_types:
159
  for model_type in model_types:
160
+ y = df[
161
+ (df["param_type"] == int(param_type))
162
+ & (df["loss_type"] == loss_type)
163
+ & (df["model_type"] == model_type)
164
+ ]["loss_interp"].values
165
  if len(y) > 0:
166
  f = interp1d(np.linspace(0, 1, len(y)), y)
167
  loss_rates.append(f(x))
168
+ labels.append(f"{param_type}_{loss_type}_{model_type}")
 
169
  fig, ax = plt.subplots()
 
 
170
  for i, loss_rate in enumerate(loss_rates):
 
 
 
 
 
 
171
  ax.plot(x, loss_rate, label=labels[i])
 
 
172
  ax.legend()
173
+ ax.set_xlabel("Training steps")
174
+ ax.set_ylabel("Loss rate")
 
175
  return fig
176
+
 
177
  @render.image
178
  def plot_model_scaling():
179
+ df = pd.read_csv("training_data_5.csv")
180
+ df = df[df["epoch_interp"] > 0.035]
181
+ fig = plot_loss_rates_model(
182
+ df, input.param_type(), input.loss_type(), input.model_type()
183
+ )
 
 
 
184
  if fig:
185
+ import tempfile
186
+
187
+ fd, path = tempfile.mkstemp(suffix=".svg")
188
  fig.savefig(path)
189
+ return {"src": str(path), "width": "600px", "format": "svg"}
190
+
191
  with ui.nav_panel("Scaling Laws"):
 
192
  ui.panel_title("Params & Losses")
 
193
  with ui.card():
 
194
  ui.input_selectize(
195
+ "model_type_scale",
196
+ "Select Model Type:",
197
+ ["pythia", "denseformer", "evo"],
198
+ multiple=True,
199
+ selected=["evo", "denseformer"],
200
+ )
201
  ui.input_selectize(
202
+ "loss_type_scale",
203
+ "Select Loss Type:",
204
+ ["compliment", "cross_entropy", "headless", "2d", "2d_representation_MSEPlusCE"],
205
+ multiple=True,
206
+ selected=["cross_entropy"],
207
+ )
208
+
209
  def plot_loss_rates_model_scale(df, loss_type, model_types):
210
+ df = df[df["loss_type"] == loss_type[0]]
 
211
  params = []
212
  loss_rates = []
213
  labels = []
214
  for model_type in model_types:
215
+ df_new = df[df["model_type"] == model_type]
216
  losses = []
217
  params_model = []
218
+ for paramy in df_new["num_params"].unique():
219
+ loss = df_new[df_new["num_params"] == paramy]["loss_interp"].min()
220
+ par = int(paramy)
221
+ losses.append(loss)
222
+ params_model.append(par)
223
+ df_reorder = pd.DataFrame({"loss": losses, "params": params_model})
224
+ df_reorder = df_reorder.sort_values(by="params")
225
+ loss_rates.append(df_reorder["loss"].to_list())
226
+ params.append(df_reorder["params"].to_list())
 
 
 
 
227
  labels.append(model_type)
 
228
  fig, ax = plt.subplots()
 
229
  for i, loss_rate in enumerate(loss_rates):
230
  ax.plot(params[i], loss_rate, label=labels[i])
 
231
  ax.legend()
232
+ ax.set_xlabel("Params")
233
+ ax.set_ylabel("Loss")
 
234
  return fig
235
+
 
 
236
  @render.image
237
  def plot_big_boy_model():
238
+ df = pd.read_csv("training_data_5.csv")
239
+ fig = plot_loss_rates_model_scale(
240
+ df, input.loss_type_scale(), input.model_type_scale()
241
+ )
 
 
242
  if fig:
243
+ import tempfile
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
+ fd, path = tempfile.mkstemp(suffix=".svg")
246
+ fig.savefig(path)
247
+ return {"src": str(path), "width": "600px", "format": "svg"}