foz committed on
Commit
0dab632
1 Parent(s): 020d824

Add more options and tests

Browse files
app.py CHANGED
@@ -2,20 +2,19 @@
2
  import gradio as gr
3
  from tifffile import imread
4
  from PIL import Image
5
- import matplotlib.pyplot as plt
6
- from analyse import analyse_paths
7
  import numpy as np
8
 
9
- def process(cell_id, foci_file, traces_file):
10
- paths, traces, fig, extracted_peaks = analyse_paths(cell_id, foci_file.name, traces_file.name)
11
- extracted_peaks.to_csv('tmp')
12
- return paths, [Image.fromarray(im) for im in traces], fig, extracted_peaks, 'tmp'
13
 
 
14
  def preview_image(file1):
15
  if file1:
16
  im = imread(file1.name)
17
- print(im.shape)
18
- return Image.fromarray(np.max(im, axis=0))
 
 
 
19
  else:
20
  return None
21
 
@@ -23,12 +22,27 @@ def preview_image(file1):
23
  with gr.Blocks() as demo:
24
  with gr.Row():
25
  with gr.Column():
 
26
  cellid_input = gr.Textbox(label="Cell ID", placeholder="Image_1")
27
  image_input = gr.File(label="Input foci image")
28
  image_preview = gr.Image(label="Max projection of foci image")
29
  image_input.change(fn=preview_image, inputs=image_input, outputs=image_preview)
30
  path_input = gr.File(label="SNT traces file")
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  with gr.Column():
33
  trace_output = gr.Image(label="Overlayed paths")
34
  image_output=gr.Gallery(label="Traced paths")
@@ -36,9 +50,25 @@ with gr.Blocks() as demo:
36
  data_output=gr.DataFrame(label="Detected peak data")#, "Peak 1 pos", "Peak 1 int"])
37
  data_file_output=gr.File(label="Output data file (.csv)")
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  with gr.Row():
40
  greet_btn = gr.Button("Process")
41
- greet_btn.click(fn=process, inputs=[cellid_input, image_input, path_input], outputs=[trace_output, image_output, plot_output, data_output, data_file_output], api_name="process")
42
 
43
 
44
  if __name__ == "__main__":
 
2
  import gradio as gr
3
  from tifffile import imread
4
  from PIL import Image
5
+ from path_analysis.analyse import analyse_paths
 
6
  import numpy as np
7
 
 
 
 
 
8
 
9
+ # Function to preview the imported image
10
def preview_image(file1):
    """Build a 2-D preview of an uploaded TIFF for the preview widget.

    Args:
        file1: Uploaded file object exposing its path as ``file1.name``, or a
            falsy value when nothing has been uploaded.

    Returns:
        A PIL image, or ``None`` when ``file1`` is falsy. Arrays with more
        than two dimensions are collapsed by taking the maximum along axis 0;
        2-D arrays are shown unchanged.
    """
    if not file1:
        return None

    im = imread(file1.name)
    print(im.ndim, im.shape)
    # Axis 0 is presumably the z axis of the stack - TODO confirm for
    # multi-channel inputs.
    projected = np.max(im, axis=0) if im.ndim > 2 else im
    return Image.fromarray(projected)
20
 
 
22
  with gr.Blocks() as demo:
23
  with gr.Row():
24
  with gr.Column():
25
+ # Inputs for cell ID, image, and path
26
  cellid_input = gr.Textbox(label="Cell ID", placeholder="Image_1")
27
  image_input = gr.File(label="Input foci image")
28
  image_preview = gr.Image(label="Max projection of foci image")
29
  image_input.change(fn=preview_image, inputs=image_input, outputs=image_preview)
30
  path_input = gr.File(label="SNT traces file")
31
 
32
+ # Additional options wrapped in an accordion for better UI experience
33
+ with gr.Accordion("Additional options ..."):
34
+ sphere_radius = gr.Number(label="Trace sphere radius (um)", value=0.1984125, interactive=True)
35
+ peak_threshold = gr.Number(label="Peak relative threshold", value=0.4, interactive=True)
36
+ # Resolutions for xy and z axis
37
+ with gr.Row():
38
+ xy_res = gr.Number(label='xy-yesolution (um)', value=0.0396825, interactive=True)
39
+ z_res = gr.Number(label='z resolution (um)', value=0.0909184, interactive=True)
40
+ # Resolutions for xy and z axis
41
+
42
+ threshold_type = gr.Radio(["per-trace", "per-cell"], label="Threshold-type", value="per-trace", interactive=True)
43
+
44
+
45
+ # The output column showing the result of processing
46
  with gr.Column():
47
  trace_output = gr.Image(label="Overlayed paths")
48
  image_output=gr.Gallery(label="Traced paths")
 
50
  data_output=gr.DataFrame(label="Detected peak data")#, "Peak 1 pos", "Peak 1 int"])
51
  data_file_output=gr.File(label="Output data file (.csv)")
52
 
53
+
54
def process(cellid_input, image_input, path_input, sphere_radius, peak_threshold, xy_res, z_res, threshold_type):
    """Run the path analysis for one cell and package the results for the UI.

    Args:
        cellid_input: Cell identifier written into the output table.
        image_input: Uploaded foci image (object exposing ``.name``).
        path_input: Uploaded SNT traces file (object exposing ``.name``).
        sphere_radius: Radius of the measurement sphere (um).
        peak_threshold: Relative threshold used to call peaks as foci.
        xy_res: Pixel size in x and y (um).
        z_res: Slice spacing in z (um).
        threshold_type: ``"per-trace"`` or ``"per-cell"``.

    Returns:
        tuple: (overlay image, list of per-trace PIL images, matplotlib
        figure, DataFrame of extracted peaks, path of the CSV file written).

    Raises:
        gr.Error: If either input file has not been uploaded.
    """
    import tempfile

    # Surface a readable message in the UI instead of an AttributeError on
    # ``None.name`` when the button is pressed before both uploads are done.
    if image_input is None or path_input is None:
        raise gr.Error("Please provide both a foci image and an SNT traces file.")

    config = {
        'sphere_radius': sphere_radius,
        'peak_threshold': peak_threshold,
        'xy_res': xy_res,
        'z_res': z_res,
        'threshold_type': threshold_type,
    }

    paths, traces, fig, extracted_peaks = analyse_paths(cellid_input, image_input.name, path_input.name, config)

    # Write to a unique temporary file: a fixed 'output.csv' in the working
    # directory would be overwritten by concurrent requests.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', prefix='peaks_',
                                     delete=False, newline='') as csv_file:
        extracted_peaks.to_csv(csv_file)
        csv_path = csv_file.name

    return paths, [Image.fromarray(im) for im in traces], fig, extracted_peaks, csv_path
67
+
68
+
69
  with gr.Row():
70
  greet_btn = gr.Button("Process")
71
+ greet_btn.click(fn=process, inputs=[cellid_input, image_input, path_input, sphere_radius, peak_threshold, xy_res, z_res, threshold_type], outputs=[trace_output, image_output, plot_output, data_output, data_file_output], api_name="process")
72
 
73
 
74
  if __name__ == "__main__":
path_analysis/__init__.py ADDED
File without changes
path_analysis/analyse.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import lxml.etree as ET
3
+ import gzip
4
+ import tifffile
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from PIL import Image, ImageDraw
8
+ import pandas as pd
9
+ from itertools import cycle
10
+ from .data_preprocess import analyse_traces
11
+ import math
12
+ import scipy.linalg as la
13
+
14
+
15
def get_paths_from_traces_file(traces_file):
    """
    Parse a traces file and extract the paths and their recorded lengths.

    Only ``<path>`` elements that are direct children of the document root
    are read, and within each path only its ``<point>`` children. Other child
    nodes (XML comments, or any non-point elements) are ignored rather than
    causing a failure when their ``x``/``y``/``z`` attributes are missing.

    Args:
        traces_file (str): Path to the XML traces file.

    Returns:
        tuple: ``(all_paths, path_lengths)`` where ``all_paths`` is a list of
        paths (each a list of ``(x, y, z)`` integer tuples) and
        ``path_lengths`` is the list of ``reallength`` attribute values as
        floats, in the same order.
    """
    root = ET.parse(traces_file).getroot()
    all_paths = []
    path_lengths = []
    for path in root.findall('path'):
        length = path.get('reallength')
        path_points = [
            (int(point.get('x')), int(point.get('y')), int(point.get('z')))
            for point in path.findall('point')
        ]
        all_paths.append(path_points)
        path_lengths.append(float(length))
    return all_paths, path_lengths
38
+
39
def calculate_path_length(point_list, voxel_size=(1,1,1)):
    """
    Sum the Euclidean lengths of the segments joining consecutive points.

    Args:
        point_list (list): Ordered points, each a sequence of coordinates.
        voxel_size (tuple, optional): Per-axis scale applied to coordinate
            differences before measuring. Defaults to (1, 1, 1).

    Returns:
        Total path length; ``0`` when there are fewer than two points.
    """
    scale = np.array(voxel_size)
    total = 0
    for start, end in zip(point_list, point_list[1:]):
        total += la.norm(scale * (np.array(end) - np.array(start)))
    return total
46
+
47
+
48
def calculate_path_length_partials(point_list, voxel_size=(1,1,1)):
    """
    Return the cumulative distance along the path at every point.

    Args:
        point_list (list): Ordered points, each a sequence of coordinates.
        voxel_size (tuple, optional): Per-axis scale applied to coordinate
            differences before measuring. Defaults to (1, 1, 1).

    Returns:
        numpy.ndarray: One entry per point; the first entry is 0.0 and the
        last is the total path length. A single-element array ``[0.0]`` is
        returned when there are fewer than two points.
    """
    scale = np.array(voxel_size)
    steps = [
        la.norm(scale * (np.array(end) - np.array(start)))
        for start, end in zip(point_list, point_list[1:])
    ]
    return np.cumsum([0.0] + steps)
55
+
56
+
57
def visualise_ordering(points_list, dim, wr=5, wc=5):
    """
    Paint the points of a path onto a blank RGB image, coloured by order.

    Each point is drawn as a filled rectangle whose colour moves from green
    (first point) towards red (last point).

    Args:
        points_list (list): Ordered points; the first coordinate of each
            point is used as the column index and the second as the row.
        dim (tuple): Three-element shape; the first two entries give the
            number of rows and columns of the output, the third is ignored.
        wr (int, optional): Half-height of each rectangle. Defaults to 5.
        wc (int, optional): Half-width of each rectangle. Defaults to 5.

    Returns:
        np.array: ``uint8`` image of shape ``(dim[0], dim[1], 3)``.
    """
    # NOTE(review): rows are sized from dim[0] but indexed with the second
    # point coordinate - confirm the axis order for non-square images.
    n_rows, n_cols, _ = dim
    canvas = np.zeros((n_rows, n_cols, 3), dtype=np.uint8)

    for order, point in enumerate(points_list):
        col, row, _ = map(int, point)
        red = int(255 * order/len(points_list))
        row_span = slice(max(0, row - wr), min(n_rows, row + wr + 1))
        col_span = slice(max(0, col - wc), min(n_cols, col + wc + 1))
        canvas[row_span, col_span] = (red, 255 - red, 0)

    return canvas
84
+
85
+ # A color map for paths
86
+ col_map = [(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255), (0,255,255),
87
+ (255,127,0), (255, 0, 127), (127, 255, 0), (0, 255, 127), (127,0,255), (0,127,255)]
88
+
89
def draw_paths(all_paths, foci_stack, foci_index=None, r=3):
    """
    Draw paths on the max projection of the image stack and mark the foci.

    Args:
        all_paths (list): List of paths where each path is a list of points.
        foci_stack (np.array): 3D numpy array representing the image stack;
            it is projected along axis 0.
        foci_index (list, optional): For each path, the indices (into that
            path) of the points to mark with a cross. Defaults to None.
        r (int, optional): Half-length of the arms of each cross. Defaults to 3.

    Returns:
        PIL.Image.Image: RGB image with the numbered paths and foci markers.
    """
    projection = np.max(foci_stack, axis=0)
    peak = np.max(projection)
    # An all-zero projection would otherwise divide by zero and cast NaNs to
    # uint8; render it as a black image instead.
    if peak != 0:
        scaled = projection/peak*255
    else:
        scaled = np.zeros_like(projection)
    im = scaled.astype(np.uint8)
    im = np.dstack((im,)*3)
    im = Image.fromarray(im)
    draw = ImageDraw.Draw(im)

    for i, (p, col) in enumerate(zip(all_paths, cycle(col_map))):
        # A path without points has nothing to draw and no start for a label.
        if len(p) == 0:
            continue
        draw.line([(u[0], u[1]) for u in p], fill=col)
        draw.text((p[0][0], p[0][1]), str(i+1), fill=col)

    if foci_index is not None:
        for idx, p, col in zip(foci_index, all_paths, cycle(col_map)):
            for j in idx:
                x, y = int(p[j][0]), int(p[j][1])
                # Horizontal and vertical arms of the marker cross.
                draw.line((int(p[j][0]-r), y, int(p[j][0]+r), y), fill=col, width=2)
                draw.line((x, int(p[j][1]-r), x, int(p[j][1]+r)), fill=col, width=2)

    return im
118
+
119
+
120
def measure_from_mask(mask, measure_stack):
    """
    Sum the values of ``measure_stack`` weighted element-wise by ``mask``.

    Args:
        mask (numpy.ndarray): Mask array; with a 0/1 mask this selects the
            region to measure.
        measure_stack (numpy.ndarray): Values to be measured.

    Returns:
        Sum over all elements of ``mask * measure_stack``.
    """
    weighted = mask * measure_stack
    return np.sum(weighted)
132
+
133
+ # Max of measure_stack over region where mask==1
134
def max_from_mask(mask, measure_stack):
    """
    Return the largest value of ``measure_stack`` weighted by ``mask``.

    Elements where the mask is 0 contribute 0 to the comparison, so the
    result is never below 0 when the mask contains a zero.

    Args:
        mask (numpy.ndarray): Mask array; with a 0/1 mask this selects the
            region to measure.
        measure_stack (numpy.ndarray): Values to be measured.

    Returns:
        Maximum over all elements of ``mask * measure_stack``.
    """
    weighted = mask * measure_stack
    return np.max(weighted)
146
+
147
def make_mask_s(p, melem, measure_stack):
    """
    Centre a structuring element on point ``p``, clipping at the stack edges.

    Two arrays with the shape of ``melem`` are produced: the part of
    ``melem`` that falls inside ``measure_stack`` (zero elsewhere), and the
    matching window of ``measure_stack`` values (zero elsewhere).

    Args:
        p (iterable): Target point as three integer indices into
            ``measure_stack``.
        melem (numpy.ndarray): Structuring element; the window arithmetic
            assumes an odd extent along each axis.
        measure_stack (numpy.ndarray): Stack of measurements.

    Returns:
        tuple: ``(mask, m_data)`` as described above.
    """
    half = [extent // 2 for extent in melem.shape]
    r, c, z = p
    centre = (r, c, z)
    stack_shape = measure_stack.shape

    mask = np.zeros(melem.shape)
    m_data = np.zeros(melem.shape)

    window = []   # slices into the melem-shaped arrays
    source = []   # slices into measure_stack
    for axis in range(3):
        h, pos, size = half[axis], centre[axis], stack_shape[axis]
        window.append(slice(max(h - pos, 0), min(h - pos + size, 2*h + 1)))
        source.append(slice(max(pos - h, 0), min(pos + h + 1, size)))
    window = tuple(window)
    source = tuple(source)

    m_data[window] = measure_stack[source]
    mask[window] = melem[window]

    return mask, m_data
179
+
180
+
181
def measure_at_point(p, melem, measure_stack, op='mean'):
    """
    Measure ``measure_stack`` around a point using a structuring element.

    Args:
        p (tuple): Target point; coordinates are converted to ``int``.
        melem (numpy.ndarray): Structuring element for the mask.
        measure_stack (numpy.ndarray): Stack of measurements.
        op (str, optional): ``'mean'`` for the mask-weighted mean; any other
            value gives the mask-weighted maximum. Default is 'mean'.

    Returns:
        float: Measured value based on the specified operation.
    """
    p = map(int, p)
    mask, m_data = make_mask_s(p, melem, measure_stack)

    if op == 'mean':
        # Normalise by the part of the element that lies inside the stack.
        return float(measure_from_mask(mask, m_data) / np.sum(mask))
    return float(max_from_mask(mask, m_data))
203
+
204
+ # Generate spherical region
205
def make_sphere(R=5, z_scale_ratio=2.3):
    """
    Build a boolean array marking the voxels inside a sphere.

    The array spans ``-R..R`` along the first two axes and
    ``-ceil(R/z_scale_ratio)..ceil(R/z_scale_ratio)`` along the third; the
    third-axis offsets are multiplied by ``z_scale_ratio`` before being
    compared against the radius.

    Args:
        R (int, optional): Radius of the sphere along the first two axes.
            Default is 5.
        z_scale_ratio (float, optional): Scaling factor for the third axis.
            Default is 2.3.

    Returns:
        numpy.ndarray: Boolean array, True inside the sphere.
    """
    R_z = int(math.ceil(R/z_scale_ratio))
    xy_offsets = np.arange(-R, R+1)
    z_offsets = np.arange(-R_z, R_z+1)

    x = xy_offsets[:, None, None]
    y = xy_offsets[None, :, None]
    z = z_offsets[None, None, :]
    return x**2 + y**2 + (z_scale_ratio * z)**2 <= R**2
220
+
221
+ # Measure the values of measure_stack at each of the points of points_list in turn.
222
+ # Measurement is the mean / max (specified by op) on the spherical region about each point
223
def measure_all_with_sphere(points_list, measure_stack, op='mean', R=5, z_scale_ratio=2.3):
    """
    Measure ``measure_stack`` at every point of a list using one sphere.

    Args:
        points_list (list): Points to be measured, in order.
        measure_stack (numpy.ndarray): Stack of measurements.
        op (str, optional): Operation passed to ``measure_at_point``
            ('mean' or 'max'). Default is 'mean'.
        R (int, optional): Radius of the sphere. Default is 5.
        z_scale_ratio (float, optional): Scaling factor for the z-axis.
            Default is 2.3.

    Returns:
        list: One measured value per point, in input order.
    """
    # The structuring element is the same for every point, so build it once.
    melem = make_sphere(R, z_scale_ratio)
    return [measure_at_point(point, melem, measure_stack, op) for point in points_list]
240
+
241
+
242
+ # Measure fluorescence levels along ordered skeleton
243
def measure_chrom2(path, hei10, config):
    """
    Measure fluorescence along an ordered path.

    Args:
        path (list): Ordered path points.
        hei10 (numpy.ndarray): 3D fluorescence data.
        config (dict): Must contain 'z_res', 'xy_res' and 'sphere_radius'.

    Returns:
        tuple: ``(vis, mean_values, max_values)`` - the ordering
        visualisation image and the per-point mean and max measurements.
    """
    xy_res = config['xy_res']
    # Sphere radius converted to whole pixels (rounded up) and the z/xy
    # anisotropy used to squash the sphere along z.
    sphere_xy_radius = int(math.ceil(config['sphere_radius']/xy_res))
    scale_ratio = config['z_res']/xy_res

    vis = visualise_ordering(path, dim=hei10.shape, wr=sphere_xy_radius, wc=sphere_xy_radius)

    sphere_args = dict(R=sphere_xy_radius, z_scale_ratio=scale_ratio)
    measurements = measure_all_with_sphere(path, hei10, op='mean', **sphere_args)
    measurements_max = measure_all_with_sphere(path, hei10, op='max', **sphere_args)

    return vis, measurements, measurements_max
264
+
265
def extract_peaks(cell_id, all_paths, path_lengths, measured_traces, config):
    """
    Extract peak information from the given traces and compile a DataFrame.

    Args:
        - cell_id (int or str): Identifier for the cell being analyzed.
        - all_paths (list of lists): Ordered path points for each path.
        - path_lengths (list of floats): Length of each path in all_paths.
        - measured_traces (list of lists): Fluorescence values along the paths.
        - config (dict): Configuration dictionary containing:
            - 'peak_threshold': Threshold value to determine a peak in the trace.
            - 'sphere_radius': Radius of the sphere used in fluorescence measurement.
            - 'xy_res', 'z_res': Voxel sizes used to measure distances along each path.
            - 'use_corrected_positions' (optional): If true, focus positions are
              taken from the distance measured along the path; otherwise the
              positions returned by ``analyse_traces`` are used. Defaults to
              True when absent (previously a missing key raised ``KeyError``
              as soon as a focus was found).

    Returns:
        - pd.DataFrame: One row per path with its peak information.
        - list of lists: Absolute intensities of the detected foci.
        - list of lists: Index positions of the detected foci.
        - list: Absolute focus intensity threshold for each trace.
        - list of numpy.ndarray: For each trace, distance of each point from
          the start of the trace.
    """
    (foci_absolute_intensity, foci_position, foci_position_index,
     trace_median_intensities, trace_thresholds) = analyse_traces(all_paths, path_lengths, measured_traces, config)

    # Total background-subtracted intensity over all foci of the cell; used
    # to express each focus as a fraction of that total.
    total_intensity = sum(sum(path_foci_abs_int - tmi) for path_foci_abs_int, tmi in zip(foci_absolute_intensity, trace_median_intensities))

    # NOTE(review): default chosen so that reported positions match the
    # measured distances returned as trace_positions - confirm with the author.
    use_corrected_positions = config.get('use_corrected_positions', True)
    voxel_size = (config['xy_res'], config['xy_res'], config['z_res'])

    data = []
    trace_positions = []
    for i, path in enumerate(all_paths):
        pl = calculate_path_length_partials(path, voxel_size)

        path_data = {
            'Cell_ID': cell_id,
            'Trace': i+1,
            'SNT_trace_length(um)': path_lengths[i],
            'Measured_trace_length(um)': pl[-1],
            'Trace_median_intensity': trace_median_intensities[i],
            'Detection_sphere_radius(um)': config['sphere_radius'],
            'Foci_ID_threshold': config['peak_threshold'],
        }
        for j, (idx, u, v) in enumerate(zip(foci_position_index[i], foci_position[i], foci_absolute_intensity[i])):
            path_data[f'Foci_{j+1}_position(um)'] = pl[idx] if use_corrected_positions else u
            path_data[f'Foci_{j+1}_absolute_intensity'] = v
            path_data[f'Foci_{j+1}_relative_intensity'] = (v - trace_median_intensities[i])/total_intensity

        data.append(path_data)
        trace_positions.append(pl)

    return pd.DataFrame(data), foci_absolute_intensity, foci_position_index, trace_thresholds, trace_positions
315
+
316
+
317
def analyse_paths(cell_id,
                  foci_file,
                  traces_file,
                  config
                  ):
    """
    Analyse the paths of one cell using the given foci image and traces file.

    Args:
        cell_id (int/str): Identifier for the cell.
        foci_file (str): Path to the foci image file.
        traces_file (str): Path to the XML traces file.
        config (dict): Configuration dictionary containing the resolutions,
            sphere radius and threshold settings used by the analysis.

    Returns:
        tuple: An overlay image of the traces, a list of visualisation images
        (one per trace), a figure with the plotted measurements, and a
        dataframe with the extracted peaks.
    """
    foci_stack = tifffile.imread(foci_file)

    # Promote a single 2-D image to a one-slice stack.
    if foci_stack.ndim == 2:
        foci_stack = foci_stack[None, :, :]

    all_paths, path_lengths = get_paths_from_traces_file(traces_file)

    # Measurements index the stack with the axis order reversed relative to
    # the file, hence the transpose.
    measure_stack = foci_stack.transpose(2, 1, 0)
    all_trace_vis = []
    all_m = []
    for p in all_paths:
        vis, m, _ = measure_chrom2(p, measure_stack, config)
        all_trace_vis.append(vis)
        all_m.append(m)

    extracted_peaks, foci_absolute_intensity, foci_pos_index, trace_thresholds, trace_positions = extract_peaks(cell_id, all_paths, path_lengths, all_m, config)

    n_cols = 2
    # Keep at least one row so a traces file without paths still yields a
    # (blank) figure instead of failing in plt.subplots.
    n_rows = max(1, (len(all_paths)+n_cols-1)//n_cols)
    # squeeze=False always returns a 2-D array of axes, whatever the grid size.
    fig, ax = plt.subplots(n_rows, n_cols, squeeze=False)
    ax = ax.flatten()

    for i, m in enumerate(all_m):
        ax[i].set_title(f'Trace {i+1}')
        ax[i].plot(trace_positions[i], m)
        if len(foci_pos_index[i]):
            ax[i].plot(trace_positions[i][foci_pos_index[i]], np.array(m)[foci_pos_index[i]], 'rx')
        ax[i].set_xlabel('Distance from start (um)')
        ax[i].set_ylabel('Intensity')
        ax[i].axhline(trace_thresholds[i], c='r', ls=':')
    # Hide the unused panels of the grid.
    for i in range(len(all_m), n_cols*n_rows):
        ax[i].axis('off')

    plt.tight_layout()
    trace_overlay = draw_paths(all_paths, foci_stack, foci_index=foci_pos_index)

    return trace_overlay, all_trace_vis, fig, extracted_peaks
path_analysis/data_preprocess.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from dataclasses import dataclass
4
+ import numpy as np
5
+ import scipy.linalg as la
6
+ from scipy.signal import find_peaks
7
+ from math import ceil
8
+
9
+
10
+
11
+
12
def thin_points(point_list, dmin=10, voxel_size=(1,1,1)):
    """
    Find points to discard so that no two kept points are closer than ``dmin``.

    Args:
        - point_list (list of tuples): Each tuple contains:
            - x: 3D coordinates of the point.
            - intensity (float): The intensity value of the point.
            - idx: A unique identifier for the point.
        - dmin (float, optional): Points closer than this distance are
          thinned. Defaults to 10.
        - voxel_size (tuple, optional): Per-axis scale applied to coordinate
          differences before measuring distance. Defaults to (1, 1, 1).

    Returns:
        - list: Identifiers of the removed points, in order of removal.

    Notes:
        - Distances are Euclidean (L2 norm).
        - Of two points within ``dmin`` of each other, the one with the
          lower intensity is removed; on a tie the second one is removed.
    """
    removed_points = []
    for i, candidate in enumerate(point_list):
        if candidate[2] in removed_points:
            continue
        for j, other in enumerate(point_list):
            if i == j or other[2] in removed_points:
                continue
            offset = (np.array(candidate[0]) - np.array(other[0]))*np.array(voxel_size)
            if la.norm(offset) < dmin:
                if candidate[1] < other[1]:
                    # The candidate loses; no need to compare it any further.
                    removed_points.append(candidate[2])
                    break
                removed_points.append(other[2])

    return removed_points
51
+
52
+
53
@dataclass
class CellData:
    """Container for the data of a single cell.

    Attributes:
        pathdata_list (list): PathData objects, one per path of the cell.
    """

    pathdata_list: list
61
+
62
@dataclass
class PathData:
    """Container for the data of one path in a cell.

    Attributes:
        peaks (list): Peaks on the path, as indices into ``points`` and
            ``o_hei10``.
        points (list): Points defining the path.
        o_hei10 (list): Unnormalized fluorescence intensity values along
            the path.
        SC_length (float): Length of the path.
    """

    peaks: list
    points: list
    o_hei10: list
    SC_length: float
79
+
80
+
81
+
82
def find_peaks2(v, distance=5, prominence=0.5):
    """
    Find peaks in a 1D array, including peaks at either end.

    The input is padded at both ends with its minimum value so that a
    maximum sitting on the first or last sample is detected, peaks are found
    in the padded array, and their indices are mapped back to the input.

    Args:
        - v (array-like): 1D input in which to find peaks.
        - distance (int, optional): Minimum number of samples separating two
          peaks. Defaults to 5.
        - prominence (float, optional): Minimum prominence required for a
          peak to be identified. Defaults to 0.5.

    Returns:
        - list of int: Indices of the identified peaks in the input array
          (empty for empty input).
        - dict: Peak properties as returned by scipy.signal.find_peaks. Any
          index-valued entries refer to the padded array, not to ``v``.

    Raises:
        RuntimeError: If a detected peak falls inside the padding, which
            would indicate an internal inconsistency.
    """
    v = np.asarray(v)
    if v.size == 0:
        # np.min has no identity on empty input; there are no peaks to find.
        return [], {}

    pad = int(ceil(distance)) + 1
    filler = np.full((pad,), np.min(v), dtype=v.dtype)
    v_ext = np.concatenate([filler, v, filler])

    peaks, properties = find_peaks(v_ext, distance=distance, prominence=prominence)
    peaks = peaks - pad

    n_peaks = []
    for i in peaks:
        if not 0 <= i < len(v):
            raise RuntimeError(f"Peak index {i} lies outside the input array of length {len(v)}")
        n_peaks.append(i)
    return n_peaks, properties
112
+
113
+
114
def process_cell_traces(all_paths, path_lengths, measured_trace_fluorescence):
    """
    Detect candidate peaks on every trace of a cell and bundle the results.

    Each trace is standardised, peaks are located on the standardised
    signal, peaks that lie close (in 3D) to a brighter peak on any trace are
    discarded, and the surviving peak indices are stored with the path
    points, raw fluorescence and path length.

    Args:
        all_paths (list of list of tuples): Paths, each a list of 3D
            coordinate tuples.
        path_lengths (list of float): Length of each path.
        measured_trace_fluorescence (list of list of float): Fluorescence
            value at each point of each path.

    Returns:
        CellData: One PathData entry per path.

    Note:
        - The three inputs are matched by index.
    """
    # Candidate peaks per trace: (indices, points at peaks, raw heights at peaks).
    cell_peaks = []
    for points, path_length, o_hei10 in zip(all_paths, path_lengths, measured_trace_fluorescence):
        # Standardise to mean zero and s.d. 1 for peak detection only.
        hei10_normalized = (o_hei10 - np.mean(o_hei10))/np.std(o_hei10)

        # Initial peak list - refined further below.
        p, _ = find_peaks2(hei10_normalized, distance=5, prominence=0.5*np.std(hei10_normalized))
        peaks = np.array(p, dtype=np.int32)

        # Heights are taken from the original, unnormalized values.
        cell_peaks.append((peaks, [points[u] for u in peaks], [o_hei10[u] for u in peaks]))

    # Drop peaks that have a brighter peak nearby in 3D space on any trace:
    # these arise when one trace passes close to a bright focus on another.
    to_thin = []
    for k, (peak_idx, peak_points, peak_heights) in enumerate(cell_peaks):
        for u in range(len(peak_idx)):
            to_thin.append((peak_points[u], peak_heights[u], (k, u)))

    removed_points = thin_points(to_thin)

    kept_peaks = [
        [peak for u, peak in enumerate(peak_idx) if (k, u) not in removed_points]
        for k, (peak_idx, _, _) in enumerate(cell_peaks)
    ]

    # Store peak indices, raw intensities and length for each path.
    pd_list = [
        PathData(peaks=kept_peaks[k], points=points,
                 o_hei10=measured_trace_fluorescence[k], SC_length=path_lengths[k])
        for k, points in enumerate(all_paths)
    ]

    return CellData(pathdata_list=pd_list)
194
+
195
+
196
alpha_max = 0.4


# Criterion used for identifying a peak as a CO - (normalized) hei10 level
# at the peak being at least alpha times the maximum peak level.
def pc(pos, v, alpha=alpha_max):
    """
    Return the positions whose values reach a fraction of the maximum value.

    The threshold is ``alpha`` times the maximum of ``v``; positions whose
    value is greater than or equal to it are kept.

    Args:
        - pos (numpy.ndarray): Array of positions.
        - v (numpy.ndarray): 1D array of values, e.g. intensities, aligned
          with ``pos``.
        - alpha (float, optional): Scaling factor for the threshold.
          Defaults to `alpha_max`.

    Returns:
        - numpy.ndarray: Positions whose value meets the threshold; an empty
          array when ``v`` is empty (e.g. a trace without any peaks).
    """
    pos = np.asarray(pos)
    v = np.asarray(v)
    if v.size == 0:
        # np.max raises on empty input; no values means nothing is selected.
        return pos[:0].copy()

    idx = (v >= alpha*np.max(v))
    return np.array(pos[idx])
217
+
218
def analyse_celldata(cell_data, config):
    """
    Decide which candidate peaks of each trace are called as foci.

    Args:
        cell_data (CellData): Object whose ``pathdata_list`` holds, for each
            path, ``peaks``, ``points``, ``o_hei10`` and ``SC_length``.
        config (dictionary): Configuration dictionary containing
            'peak_threshold' (float) - relative threshold for calling peaks as foci
            'threshold_type' (str) - 'per-trace' or 'per-cell'

            * 'per-trace': a peak is a focus if its mean-subtracted intensity
              is at least ``peak_threshold`` times the largest mean-subtracted
              peak intensity on the same trace.
            * 'per-cell': a peak is a focus if its mean-subtracted intensity
              exceeds ``peak_threshold`` times the largest mean-subtracted
              intensity found on any trace of the cell.

    Returns:
        tuple: Five lists with one entry per path:
            - foci_abs_intensity: raw intensities of the detected foci.
            - foci_pos: positions of the foci, as index fraction times
              ``SC_length``.
            - foci_pos_index: indices of the detected foci.
            - trace_median_intensities: median raw intensity of the trace.
            - trace_thresholds: absolute intensity threshold applied to the
              trace (NaN for a 'per-trace' trace without candidate peaks).

    Raises:
        NotImplementedError: If 'threshold_type' is not a supported value.
    """
    peak_threshold = config['peak_threshold']
    threshold_type = config['threshold_type']

    if threshold_type not in ('per-trace', 'per-cell'):
        raise NotImplementedError(f"Unknown threshold_type: {threshold_type!r}")

    foci_abs_intensity = []
    foci_pos = []
    foci_pos_index = []
    trace_median_intensities = []
    trace_thresholds = []

    # 'per-cell' needs the brightest mean-subtracted value over all traces
    # before any single trace can be thresholded.
    max_cell_intensity = float("-inf")
    if threshold_type == 'per-cell':
        for path_data in cell_data.pathdata_list:
            values = np.array(path_data.o_hei10)
            max_cell_intensity = max(max_cell_intensity, np.max(values - np.mean(values)))

    for path_data in cell_data.pathdata_list:
        peaks = np.array(path_data.peaks, dtype=np.int32)
        values = np.array(path_data.o_hei10)
        trace_mean = np.mean(values)
        h = values - trace_mean

        if threshold_type == 'per-cell':
            sig_peak_idx = peaks[h[peaks] > peak_threshold*max_cell_intensity]
            threshold = trace_mean + peak_threshold*max_cell_intensity
        elif peaks.size:
            # Dividing by the standard deviation rescales both sides of the
            # comparison alike; it is skipped for a flat trace to avoid 0/0.
            spread = np.std(values)
            if spread > 0:
                h = h/spread
            peak_h = h[peaks]
            # Same criterion as pc(): keep peaks at or above the given
            # fraction of the largest peak on this trace.
            sig_peak_idx = peaks[peak_h >= peak_threshold*np.max(peak_h)]
            threshold = (1-peak_threshold)*trace_mean + peak_threshold*np.max(values[peaks])
        else:
            # No candidate peaks on this trace: nothing to call, and the
            # per-trace threshold (defined via the largest peak) is undefined.
            sig_peak_idx = peaks
            threshold = float('nan')

        trace_thresholds.append(threshold)
        foci_pos.append((sig_peak_idx/len(path_data.points))*path_data.SC_length)
        foci_abs_intensity.append(values[sig_peak_idx])
        foci_pos_index.append(sig_peak_idx)
        trace_median_intensities.append(np.median(values))

    return foci_abs_intensity, foci_pos, foci_pos_index, trace_median_intensities, trace_thresholds
305
+
306
def analyse_traces(all_paths, path_lengths, measured_trace_fluorescence, config):
    """
    Detect candidate peaks on all traces and call foci among them.

    Args:
        all_paths (list): Paths, each a list of 3D points.
        path_lengths (list of float): Length of each path.
        measured_trace_fluorescence (list): Fluorescence values along each path.
        config (dict): Passed through to ``analyse_celldata``.

    Returns:
        tuple: The result of ``analyse_celldata`` for the processed traces.
    """
    return analyse_celldata(
        process_cell_traces(all_paths, path_lengths, measured_trace_fluorescence),
        config,
    )
311
+
312
+
313
+
314
+
setup.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
# Package metadata for the ``path_analysis`` distribution.
# NOTE(review): description, author, author_email, url and keywords still hold
# template placeholder values - replace them before publishing.
setup(
    name='path_analysis',
    version='0.1.0',
    description='A brief description of your package',  # TODO: placeholder text
    author='Your Name',  # TODO: placeholder
    author_email='[email protected]',  # TODO: placeholder
    url='https://github.com/yourusername/yourrepository',  # TODO: placeholder URL
    packages=find_packages(),  # discover packages automatically instead of listing them
    install_requires=[
        'numpy',
        'gradio',
        # TODO(review): confirm whether further runtime dependencies belong here
    ],
    classifiers=[
        'Development Status :: 3 - Alpha',
        'Intended Audience :: Developers',
        'Programming Language :: Python :: 3',
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
    ],
    python_requires='>=3.6',  # minimum supported interpreter version
    keywords='some keywords related to your project',  # TODO: placeholder
)
tests/__init__.py ADDED
File without changes
tests/test_analyse.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from path_analysis.analyse import *
3
+ import numpy as np
4
+ from math import pi
5
+ import xml.etree.ElementTree as ET
6
+
7
+
8
def test_get_paths_from_traces_file():
    """Check that paths and their ``reallength`` values are read from a traces XML file.

    The XML is written into a throw-away directory, so the working directory
    is never polluted and the file is removed even when an assertion fails.
    """
    import os
    import tempfile

    # Mock the XML traces file content
    xml_content = '''<?xml version="1.0"?>
<root>
    <path reallength="5.0">
        <point x="1" y="2" z="3"/>
        <point x="4" y="5" z="6"/>
    </path>
    <path reallength="10.0">
        <point x="7" y="8" z="9"/>
        <point x="10" y="11" z="12"/>
    </path>
</root>
'''

    with tempfile.TemporaryDirectory() as tmp_dir:
        traces_path = os.path.join(tmp_dir, "temp_traces.xml")
        with open(traces_path, "w") as f:
            f.write(xml_content)

        all_paths, path_lengths = get_paths_from_traces_file(traces_path)

    expected_paths = [[(1, 2, 3), (4, 5, 6)], [(7, 8, 9), (10, 11, 12)]]
    expected_lengths = [5.0, 10.0]

    assert all_paths == expected_paths, f"Expected paths {expected_paths}, but got {all_paths}"
    assert path_lengths == expected_lengths, f"Expected lengths {expected_lengths}, but got {path_lengths}"
38
+
39
+
40
def test_measure_chrom2():
    """Check output lengths and value range of measure_chrom2 on random data.

    The fluorescence volume is drawn from a seeded generator so that any
    failure can be reproduced exactly.
    """
    # Mock data
    path = [(2, 3, 4), (4, 5, 6), (9, 9, 9)]  # Sample ordered path points
    rng = np.random.default_rng(0)
    hei10 = rng.random((10, 10, 10))  # Random 3D fluorescence data in [0, 1)
    config = {
        'z_res': 1,
        'xy_res': 0.5,
        'sphere_radius': 2.5
    }

    # Function call
    _, measurements, measurements_max = measure_chrom2(path, hei10, config)

    # Assertions
    assert len(measurements) == len(path), "Measurements length should match path length"
    assert len(measurements_max) == len(path), "Max measurements length should match path length"
    assert all(0 <= val <= 1 for val in measurements), "All mean measurements should be between 0 and 1 for this mock data"
    assert all(0 <= val <= 1 for val in measurements_max), "All max measurements should be between 0 and 1 for this mock data"
58
+
59
def test_measure_chrom2_z():
    """Check measure_chrom2 on a volume whose value equals its last-axis index."""
    # The third meshgrid output varies only along the last axis.
    axis = np.arange(10)
    _, _, hei10 = np.meshgrid(axis, np.arange(10), np.arange(10))
    config = dict(z_res=1, xy_res=0.5, sphere_radius=2.5)
    path = [(2, 3, 4), (4, 5, 6)]

    _, measurements, measurements_max = measure_chrom2(path, hei10, config)

    n_points = len(path)
    assert len(measurements) == n_points, "Measurements length should match path length"
    assert len(measurements_max) == n_points, "Max measurements length should match path length"
    expected_mean = np.array([4, 6])
    expected_max = np.array([6, 8])
    assert all(measurements == expected_mean)
    assert all(measurements_max == expected_max)
77
+
78
def test_measure_chrom2_z2():
    """Check the maxima from measure_chrom2 when z_res is finer than xy_res."""
    # The third meshgrid output varies only along the last axis.
    axis = np.arange(10)
    _, _, hei10 = np.meshgrid(axis, np.arange(10), np.arange(10))
    config = dict(z_res=0.25, xy_res=0.5, sphere_radius=2.5)
    path = [(0, 0, 0), (2, 3, 4), (4, 5, 6)]

    _, measurements, measurements_max = measure_chrom2(path, hei10, config)

    n_points = len(path)
    assert len(measurements) == n_points, "Measurements length should match path length"
    assert len(measurements_max) == n_points, "Max measurements length should match path length"
    expected_max = np.array([9, 9, 9])
    assert all(measurements_max == expected_max)
95
+
96
+
97
def test_measure_from_mask():
    """A plus-shaped mask should select values summing to 4+4+8+4+4 = 24."""
    values = np.array([
        [2, 4, 2],
        [4, 8, 4],
        [2, 4, 2],
    ])
    plus_mask = np.array([
        [0, 1, 0],
        [1, 1, 1],
        [0, 1, 0],
    ])
    total = measure_from_mask(plus_mask, values)
    assert total == 24
110
+
111
def test_max_from_mask():
    """The largest value under a plus-shaped mask should be the centre value 8."""
    values = np.array([
        [2, 5, 2],
        [4, 8, 3],
        [2, 7, 2],
    ])
    plus_mask = np.array([
        [0, 1, 0],
        [1, 1, 1],
        [0, 1, 0],
    ])
    largest = max_from_mask(plus_mask, values)
    assert largest == 8
124
+
125
+
126
def test_measure_at_point_mean():
    """Mean over a 3x3x3 element centred at (1, 1, 1); the 27 covered values sum to 108."""
    element = np.ones((3, 3, 3))
    centre = (1, 1, 1)
    stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [3, 3, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    mean_value = measure_at_point(centre, element, stack, op='mean')
    assert mean_value == 4, "Expected mean: 4"
137
+
138
def test_measure_at_point_mean_off1():
    """Mean with a 3x3x3 element centred on the corner voxel (0, 0, 0).

    The expected value 4.5 equals the average of the eight stack values at
    indices 0..1 along every axis (sum 36).
    """
    element = np.ones((3, 3, 3))
    corner = (0, 0, 0)
    stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [5, 5, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    mean_value = measure_at_point(corner, element, stack, op='mean')
    assert mean_value == 4.5, "Expected mean: 4.5"
149
+
150
def test_measure_at_point_mean_off2():
    """Mean with a 3x3x3 element centred on the last plane of the first axis.

    The expected value 32/18 equals the average over the 18 stack values at
    first-axis indices 2 and 3 (plane 2 sums to 32, plane 3 is all zeros).
    """
    measure_stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [5, 5, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    p = (3, 1, 1)
    melem = np.ones((3, 3, 3))

    result = measure_at_point(p, melem, measure_stack, op='mean')
    # 32/18 has no exact binary representation, so compare with a tolerance.
    assert abs(result - 32/18) < 1e-9, "Expected mean: 32/18"
163
+
164
def test_measure_at_point_mean_off3():
    """Mean with a 1x1x3 element at (3, 1, 1), where the stack holds only zeros.

    The element spans measure_stack[3, 1, 0:3], all of which are 0.
    """
    measure_stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [5, 5, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    p = (3, 1, 1)
    melem = np.ones((1, 1, 3))

    result = measure_at_point(p, melem, measure_stack, op='mean')
    assert result == 0, "Expected mean: 0"
177
+
178
def test_measure_at_point_mean_off4():
    """Mean with a 3x1x1 element centred on the last plane of the first axis.

    This test has its own name so that it does not replace the 1x1x3 variant
    when pytest collects the module. The expected value 3 equals the average
    of measure_stack[2, 1, 1] (6) and measure_stack[3, 1, 1] (0).
    """
    measure_stack = np.array([
        [[2, 2, 2, 0], [4, 4, 6, 0], [5, 5, 2, 0], [0, 0, 0, 0]],
        [[4, 4, 4, 0], [8, 8, 8, 0], [4, 4, 4, 0], [0, 0, 0, 0]],
        [[3, 3, 3, 0], [6, 6, 4, 0], [3, 2, 2, 0], [0, 0, 0, 0]],
        [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
    ])
    p = (3, 1, 1)
    melem = np.ones((3, 1, 1))

    result = measure_at_point(p, melem, measure_stack, op='mean')
    assert result == 3, "Expected mean: 3"
191
+
192
+
193
def test_measure_at_point_max():
    """Maximum over a 3x3x3 element covering the whole 3x3x3 stack should be 9."""
    element = np.ones((3, 3, 3))
    centre = (1, 1, 1)
    stack = np.array([
        [[2, 2, 2], [4, 4, 4], [2, 2, 2]],
        [[4, 5, 4], [8, 7, 9], [4, 4, 4]],
        [[2, 2, 2], [4, 4, 4], [2, 2, 2]],
    ])
    largest = measure_at_point(centre, element, stack, op='max')
    assert largest == 9, "Expected max: 9"
203
+
204
+
205
def test_make_sphere_equal():
    """Check type, shape, symmetry, volume and boundary voxels of an isotropic sphere mask."""
    R = 5
    sphere = make_sphere(R, 1.0)

    # Check the returned type
    assert isinstance(sphere, np.ndarray), "Output should be a numpy ndarray"

    # Check the shape
    side = 2*R + 1
    expected_shape = (side, side, side)
    assert sphere.shape == expected_shape, f"Expected shape {expected_shape}, but got {sphere.shape}"

    # The mask must be unchanged when mirrored along any single axis.
    for axis in (2, 1, 0):
        assert (np.flip(sphere, axis=axis) == sphere).all(), "Expected symmetrical mask"

    ideal_volume = 4/3*pi*R**3
    assert abs(np.sum(sphere)-ideal_volume)<10, "Expected approximate volume to be correct"
    assert (sphere[R,R,0] == 1), "Expected centre point on top plane to be within sphere"
    assert (sphere[R+1,R,0] == 0), "Expected point next to centre on top plane to be outside sphere"
tests/test_preprocess.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from path_analysis.data_preprocess import *
2
+ import numpy as np
3
+ import pytest
4
+
5
def test_thin_points():
    """Check which indices thin_points reports for clustered and well-separated points."""
    # Each entry is (coordinates, intensity, index).
    clustered = [
        ([0, 0, 0], 10, 0),
        ([1, 1, 1], 8, 1),
        ([10, 10, 10], 12, 2),
        ([10.5, 10.5, 10.5], 5, 3),
        ([20, 20, 20], 15, 4),
    ]

    # Points 1 and 3 each lie within 5 units of a brighter neighbour
    # (points 0 and 2 respectively), so they are the ones expected to go.
    dropped = thin_points(clustered, dmin=5)
    assert set(dropped) == {1, 3}

    # With every point far from the others nothing should be removed.
    separated = [
        ([0, 0, 0], 10, 0),
        ([100, 100, 100], 12, 1),
        ([200, 200, 200], 15, 2),
    ]
    dropped_far = thin_points(separated, dmin=5)
    assert len(dropped_far) == 0
32
+
33
+
34
def test_find_peaks2():
    """Exercise find_peaks2 on small hand-made signals, including peaks at the ends."""

    # Basic test
    data = np.array([0, 0, 0, 0, 0, 0, 5, 0, 3, 0])
    peaks, _ = find_peaks2(data)
    assert set(peaks) == {6}  # Expected peaks at positions 6

    # Basic test
    data = np.array([0, 2, 0, 0, 0, 0, 0, 0, 0, 0])
    peaks, _ = find_peaks2(data)
    assert set(peaks) == {1}  # Expected peaks at positions 1

    # Test with padding impacting peak detection
    data = np.array([3, 2.9, 0, 0, 0, 3])
    peaks, _ = find_peaks2(data)
    assert set(peaks) == {0, 5}  # Peaks at both ends

    # Test with close peaks and the default distance
    data = np.array([3, 0, 3])
    peaks, _ = find_peaks2(data)
    assert set(peaks) == {2}  # Peak at right end only

    # Test with close peaks and distance=1
    data = np.array([3, 0, 3])
    peaks, _ = find_peaks2(data, distance=1)
    assert set(peaks) == {0, 2}  # Peaks at both ends

    # Test with plateaus of equal maximum values
    data = np.array([0, 3, 3, 3, 0, 3, 3, 3, 3, 3, 3])
    peaks, _ = find_peaks2(data, distance=1)
    assert set(peaks) == {2, 7}  # Peak at centre (rounded to the left) of groups of maximum values

    # Test with prominence threshold
    data = np.array([0, 1, 0, 0.4, 0])
    peaks, _ = find_peaks2(data, prominence=0.5)
    # Convert to a list first: comparing an array with a list is element-wise
    # and its truth value is only defined for a single-element result.
    assert list(peaks) == [1]  # Only the peak at position 1 meets the prominence threshold
73
+
74
+
75
def test_pc():
    """Check which positions pc returns for several alpha values.

    NOTE(review): the expected results are consistent with pc keeping the
    positions whose value is at least alpha * max(values), with a default
    alpha of 0.4 - confirm against the implementation of pc.
    """
    pos = np.array([0, 1, 2, 3, 4, 6])
    values = np.array([0.1, 0.5, 0.2, 0.8, 0.3, 0.9])

    # Default alpha: 0.5, 0.8 and 0.9 are expected to be kept (max is 0.9)
    assert np.array_equal(pc(pos, values), np.array([1, 3, 6]))

    # Custom alpha=0.5: the same three values are expected to be kept
    assert np.array_equal(pc(pos, values, alpha=0.5), np.array([1, 3, 6]))

    # alpha=1.0: only the position of the maximum itself is expected
    assert np.array_equal(pc(pos, values, alpha=1.0), np.array([6]))

    # Monotonic values with max 0.4: everything except the first value is expected
    values = np.array([0.1, 0.2, 0.3, 0.4])

    assert np.array_equal(pc(pos[:4], values), np.array([1, 2, 3]))
92
+
93
@pytest.fixture
def mock_data():
    """Provide two mock traced paths with their lengths and fluorescence values."""
    first_path = [(0, 0, 0), (0, 2, 0), (0, 5, 0), (0, 10, 0), (0, 15, 0), (0, 20, 0)]
    second_path = [(1, 20, 0), (1, 20, 10), (1, 20, 20)]

    all_paths = [first_path, second_path]
    path_lengths = [2.2, 2.3]
    # One fluorescence value per point of the corresponding path.
    measured_trace_fluorescence = [[100, 8, 3, 2, 3, 39], [38, 2, 20]]
    return all_paths, path_lengths, measured_trace_fluorescence
99
+
100
def test_process_cell_traces_return_type(mock_data):
    """process_cell_traces should return a CellData instance."""
    paths, lengths, fluorescence = mock_data
    result = process_cell_traces(paths, lengths, fluorescence)
    is_cell_data = isinstance(result, CellData)
    assert is_cell_data, f"Expected CellData but got {type(result)}"
104
+
105
def test_process_cell_traces_pathdata_list_length(mock_data):
    """The result should hold one pathdata entry per input path."""
    all_paths, lengths, fluorescence = mock_data
    result = process_cell_traces(all_paths, lengths, fluorescence)
    assert len(result.pathdata_list) == len(all_paths), f"Expected {len(all_paths)} but got {len(result.pathdata_list)}"
109
+
110
def test_process_cell_traces_pathdata_path_lengths(mock_data):
    """Each pathdata entry should carry the input path length as SC_length."""
    paths, lengths, fluorescence = mock_data
    result = process_cell_traces(paths, lengths, fluorescence)

    expected_path_lengths = [2.2, 2.3]
    path_lengths = []
    for entry in result.pathdata_list:
        path_lengths.append(entry.SC_length)
    assert path_lengths == expected_path_lengths, f"Expected {expected_path_lengths} but got {path_lengths}"
116
+
117
def test_process_cell_traces_peaks(mock_data):
    """Peaks are expected at indices 0 and 5 of the first path and none on the second."""
    paths, lengths, fluorescence = mock_data
    result = process_cell_traces(paths, lengths, fluorescence)

    peaks = []
    for entry in result.pathdata_list:
        peaks.append(entry.peaks)
    assert peaks == [[0, 5], []]
122
+
123
+ # Mock data
124
@pytest.fixture
def mock_celldata():
    """Provide a CellData holding two hand-built PathData entries."""
    long_path = PathData(
        peaks=[0, 5],
        points=[(0, 0, 0), (0, 2, 0), (0, 5, 0), (0, 10, 0), (0, 15, 0), (0, 20, 0)],
        o_hei10=[100, 8, 3, 2, 3, 39],
        SC_length=2.2,
    )
    short_path = PathData(
        peaks=[0],
        points=[(1, 20, 0), (1, 20, 10), (1, 20, 20)],
        o_hei10=[38, 2, 20],
        SC_length=2.3,
    )
    return CellData(pathdata_list=[long_path, short_path])
129
+
130
def test_analyse_celldata_output_length(mock_celldata):
    """The first three outputs of analyse_celldata should have one entry per path."""
    config = {'peak_threshold': 0.4, 'threshold_type': 'per-trace'}
    outputs = analyse_celldata(mock_celldata, config)
    # Five values are returned; only the first three are checked here.
    rel_intensity, pos, pos_index, _median_intensity, _thresholds = outputs

    n_paths = len(mock_celldata.pathdata_list)
    assert len(rel_intensity) == n_paths, "Mismatch in relative intensities length"
    assert len(pos) == n_paths, "Mismatch in positions length"
    assert len(pos_index) == n_paths, "Mismatch in position indices length"
135
+
136
+
137
+
138
+
139
+