porestar commited on
Commit
c2d0da9
·
1 Parent(s): 8d3f1a3

Update app.py, score_db.py, and requirements.txt

Browse files
Files changed (3) hide show
  1. app.py +144 -4
  2. requirements.txt +5 -0
  3. score_db.py +143 -0
app.py CHANGED
@@ -1,7 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ import base64
2
+ import io
3
+ import random
4
+ from io import BytesIO
5
+
6
+ import matplotlib
7
+ matplotlib.use('Agg')
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ from PIL import Image
11
+ import requests
12
+ from datasets import load_dataset
13
  import gradio as gr
14
 
15
+ from score_db import Battle
16
+ from score_db import Model as ModelEnum, Winner
17
+
18
+ def make_plot(seismic, predicted_image):
19
+ fig, ax = plt.subplots(1, 1, figsize=(10, 10))
20
+ ax.imshow(Image.fromarray(seismic), cmap="gray")
21
+ ax.imshow(predicted_image, cmap="Reds", alpha=0.5, vmin=0, vmax=1)
22
+ ax.set_axis_off()
23
+ fig.canvas.draw()
24
+
25
+ # Create a bytes buffer to save the plot
26
+ buf = io.BytesIO()
27
+ plt.savefig(buf, format='png', bbox_inches='tight')
28
+ buf.seek(0)
29
+
30
+ # Open the PNG image from the buffer and convert it to a NumPy array
31
+ image = np.array(Image.open(buf))
32
+ return image
33
+
34
+ def call_endpoint(model: ModelEnum, img_array, url: str="https://lukasmosser--seisbase-endpoints-predict.modal.run"):
35
+ response = requests.post(url, json={"img": img_array.tolist(), "model": model})
36
+
37
+ if response:
38
+ # Parse the base64-encoded image data
39
+ if response.text.startswith("data:image/tiff;base64,"):
40
+ img_data_out = base64.b64decode(response.text.split(",")[1])
41
+ predicted_image = np.array(Image.open(BytesIO(img_data_out)))
42
+ return predicted_image
43
+
44
+ def select_random_image(dataset):
45
+ idx = random.randint(0, len(dataset))
46
+ return idx, np.array(dataset[idx]["seismic"])
47
+
48
+ def select_random_models():
49
+ model_a = random.choice(list(ModelEnum))
50
+ model_b = random.choice(list(ModelEnum))
51
+ return model_a, model_b
52
+
53
+
54
+ # Create a Gradio interface
55
+ with gr.Blocks() as evaluation:
56
+ gr.Markdown("""
57
+ ## Seismic Fault Detection Model Evaluation
58
+ This application allows you to compare the performance of different seismic fault detection models.
59
+ Two models are selected randomly, and their predictions are displayed side by side.
60
+ You can choose the better model or mark it as a tie. The results are recorded and used to update the model ratings.
61
+ """)
62
+
63
+ battle = gr.State([])
64
+ radio = gr.Radio(choices=["Less than 5 years", "5 to 20 years", "more than 20 years"], label="How much experience do you have in seismic fault interpretation?")
65
+ with gr.Row():
66
+ output_img1 = gr.Image(label="Model A Image")
67
+ output_img2 = gr.Image(label="Model B Image")
68
+
69
+ def show_images():
70
+ dataset = load_dataset("porestar/crossdomainfoundationmodeladaption-deepfault", split="valid")
71
+ idx, image_1 = select_random_image(dataset)
72
+ model_a, model_b = select_random_models()
73
+ fault_probability_1 = call_endpoint(model_a, image_1)
74
+ fault_probability_2 = call_endpoint(model_b, image_1)
75
+
76
+ img_1 = make_plot(image_1, fault_probability_1)
77
+ img_2 = make_plot(image_1, fault_probability_2)
78
+ experience = 1
79
+ if radio.value == "5 to 20 years":
80
+ experience = 2
81
+ elif radio.value == "more than 20 years":
82
+ experience = 3
83
+ battle.value.append(Battle(model_a=model_a, model_b=model_b, winner="tie", judge="None", experience=experience, image_idx=idx))
84
+ return img_1, img_2
85
+
86
+ # Define the function to make an API call
87
+ def make_api_call(choice: Winner):
88
+ api_url = "https://lukasmosser--seisbase-eval-add-battle.modal.run"
89
+ battle_out = battle.value
90
+ battle_out[-1].winner = choice
91
+ experience = 1
92
+ if radio.value == "5 to 20 years":
93
+ experience = 2
94
+ elif radio.value == "more than 20 years":
95
+ experience = 3
96
+ battle_out[-1].experience = experience
97
+ response = requests.post(api_url, json=battle_out[-1].dict())
98
+
99
+ # Load images on startup
100
+ evaluation.load(show_images, inputs=[], outputs=[output_img1, output_img2])
101
+
102
+ with gr.Row():
103
+ btn_winner_a = gr.Button("Winner Model A")
104
+ btn_tie = gr.Button("Tie")
105
+ btn_winner_b = gr.Button("Winner Model B")
106
+
107
+ # Define button click events
108
+ btn_winner_a.click(lambda: make_api_call(Winner.model_a), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
109
+ btn_tie.click(lambda: make_api_call(Winner.tie), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
110
+ btn_winner_b.click(lambda: make_api_call(Winner.model_b), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
111
+
112
+ with gr.Blocks() as leaderboard:
113
+ def get_results():
114
+ response = requests.get("https://lukasmosser--seisbase-eval-compute-ratings.modal.run")
115
+ data = response.json()
116
+
117
+ models = [entry["model"] for entry in data]
118
+ elo_ratings = [entry["elo_rating"] for entry in data]
119
+
120
+ fig, ax = plt.subplots()
121
+ ax.barh(models, elo_ratings, color='skyblue')
122
+ ax.set_xlabel('ELO Rating')
123
+ ax.set_title('Model ELO Ratings')
124
+ plt.tight_layout()
125
+
126
+ fig.canvas.draw()
127
+
128
+ # Create a bytes buffer to save the plot
129
+ buf = io.BytesIO()
130
+ plt.savefig(buf, format='png', bbox_inches='tight')
131
+ buf.seek(0)
132
+
133
+ # Open the PNG image from the buffer and convert it to a NumPy array
134
+ image = np.array(Image.open(buf))
135
+ return image
136
+
137
+ with gr.Row():
138
+ elo_ratings = gr.Image(label="ELO Ratings")
139
+
140
+ leaderboard.load(get_results, inputs=[], outputs=[elo_ratings])
141
+
142
+ demo = gr.TabbedInterface([evaluation, leaderboard], ["Arena", "Leaderboard"])
143
+
144
+ # Launch the interface
145
+ if __name__ == "__main__":
146
+ demo.launch(show_error=True)
147
 
 
 
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ matplotlib
2
+ numpy
3
+ gradio
4
+ datasets
5
+ requests
score_db.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import io
3
+ import json
4
+ import os
5
+ from datetime import datetime
6
+ from enum import Enum
7
+ from pathlib import Path
8
+ from typing import List
9
+
10
+ import pandas as pd
11
+ from fastapi import Response
12
+ from modal import web_endpoint
13
+ import modal
14
+ from pydantic import BaseModel
15
+
16
+ from rating import compute_mle_elo
17
+
18
+ # -----------------------
19
+ # Data Model Definition
20
+ # -----------------------
21
+ class ExperienceEnum(int, Enum):
22
+ novice = 1
23
+ intermediate = 2
24
+ expert = 3
25
+
26
+ class Winner(str, Enum):
27
+ model_a = "model_a"
28
+ model_b = "model_b"
29
+ tie = "tie"
30
+
31
+ class Model(str, Enum):
32
+ porestar_deepfault_unet_baseline_1 = "porestar/deepfault-unet-baseline-1"
33
+ porestar_deepfault_unet_baseline_2 = "porestar/deepfault-unet-baseline-2"
34
+
35
+ class Battle(BaseModel):
36
+ model_a: Model
37
+ model_b: Model
38
+ winner: Winner
39
+ judge: str
40
+ image_idx: int
41
+ experience: ExperienceEnum = ExperienceEnum.novice
42
+ tstamp: str = str(datetime.now())
43
+
44
+ class EloRating(BaseModel):
45
+ model: Model
46
+ elo_rating: float
47
+
48
+ # -----------------------
49
+ # Modal Configuration
50
+ # -----------------------
51
+
52
+ # Create a volume to persist data
53
+ data_volume = modal.Volume.from_name("seisbase-data", create_if_missing=True)
54
+
55
+ JSON_FILE_PATH = Path("/data/battles.json")
56
+ RESULTS_FILE_PATH = Path("/data/ratings.csv")
57
+
58
+ app_image = modal.Image.debian_slim(python_version="3.10").pip_install("pandas", "scikit-learn", "tqdm", "sympy")
59
+
60
+ app = modal.App(
61
+ image=app_image,
62
+ name="seisbase-eval",
63
+ volumes={"/data": data_volume},
64
+ )
65
+
66
+ def ensure_json_file():
67
+ """Ensure the JSON file exists and is initialized with an empty array if necessary."""
68
+ if not os.path.exists(JSON_FILE_PATH):
69
+ JSON_FILE_PATH.parent.mkdir(parents=True, exist_ok=True)
70
+ with open(JSON_FILE_PATH, "w") as f:
71
+ json.dump([], f)
72
+
73
+ def append_to_json_file(data):
74
+ """Append data to the JSON file."""
75
+ ensure_json_file()
76
+ try:
77
+ with open(JSON_FILE_PATH, "r+") as f:
78
+ try:
79
+ battles = json.load(f)
80
+ except json.JSONDecodeError:
81
+ # Reset the file if corrupted
82
+ battles = []
83
+ battles.append(data)
84
+ f.seek(0)
85
+ json.dump(battles, f, indent=4)
86
+ f.truncate()
87
+ except Exception as e:
88
+ raise RuntimeError(f"Failed to append data to JSON file: {e}")
89
+
90
+ def read_json_file():
91
+ """Read data from the JSON file."""
92
+ ensure_json_file()
93
+ try:
94
+ with open(JSON_FILE_PATH, "r") as f:
95
+ try:
96
+ return json.load(f)
97
+ except json.JSONDecodeError:
98
+ return [] # Return an empty list if the file is corrupted
99
+ except Exception as e:
100
+ raise RuntimeError(f"Failed to read JSON file: {e}")
101
+
102
+ @app.function()
103
+ @web_endpoint(method="POST", docs=True)
104
+ def add_battle(battle: Battle):
105
+ """Add a new battle to the JSON file."""
106
+ append_to_json_file(battle.dict())
107
+ return {"status": "success", "battle": battle.dict()}
108
+
109
+
110
+ @app.function()
111
+ @web_endpoint(method="GET", docs=True)
112
+ def export_csv():
113
+ """Fetch all battles and return as CSV."""
114
+ battles = read_json_file()
115
+
116
+ # Create CSV in memory
117
+ output = io.StringIO()
118
+ writer = csv.DictWriter(output, fieldnames=["model_a", "model_b", "winner", "judge", "imaged_idx", "experience", "tstamp"])
119
+ writer.writeheader()
120
+ writer.writerows(battles)
121
+
122
+ csv_data = output.getvalue()
123
+ return Response(content=csv_data, media_type="text/csv")
124
+
125
+ @app.function()
126
+ @web_endpoint(method="GET", docs=True)
127
+ def compute_ratings() -> List[EloRating]:
128
+ """Compute ratings from battles."""
129
+ battles = pd.read_json(JSON_FILE_PATH, dtype=[str, str, str, str, int, int, str]).sort_values(ascending=True, by=["tstamp"]).reset_index(drop=True)
130
+ elo_mle_ratings = compute_mle_elo(battles)
131
+ elo_mle_ratings.to_csv(RESULTS_FILE_PATH)
132
+
133
+ df = pd.read_csv(RESULTS_FILE_PATH)
134
+ df.columns = ["Model", "Elo rating"]
135
+ df = df.sort_values("Elo rating", ascending=False).reset_index(drop=True)
136
+ scores = []
137
+ for i in range(len(df)):
138
+ scores.append(EloRating(model=df["Model"][i], elo_rating=df["Elo rating"][i]))
139
+ return scores
140
+
141
+ @app.local_entrypoint()
142
+ def main():
143
+ print("Local entrypoint running. Check endpoints for functionality.")