Update app.py, score_db.py, and requirements.txt
- app.py +144 -4
- requirements.txt +5 -0
- score_db.py +143 -0
app.py
CHANGED
@@ -1,7 +1,147 @@
+import base64
+import io
+import random
+from io import BytesIO
+
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import numpy as np
+from PIL import Image
+import requests
+from datasets import load_dataset
 import gradio as gr
 
-
-
+from score_db import Battle
+from score_db import Model as ModelEnum, Winner
 
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+def make_plot(seismic, predicted_image):
+    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
+    ax.imshow(Image.fromarray(seismic), cmap="gray")
+    ax.imshow(predicted_image, cmap="Reds", alpha=0.5, vmin=0, vmax=1)
+    ax.set_axis_off()
+    fig.canvas.draw()
+
+    # Create a bytes buffer to save the plot
+    buf = io.BytesIO()
+    plt.savefig(buf, format='png', bbox_inches='tight')
+    buf.seek(0)
+
+    # Open the PNG image from the buffer and convert it to a NumPy array
+    image = np.array(Image.open(buf))
+    return image
+
+def call_endpoint(model: ModelEnum, img_array, url: str="https://lukasmosser--seisbase-endpoints-predict.modal.run"):
+    response = requests.post(url, json={"img": img_array.tolist(), "model": model})
+
+    if response:
+        # Parse the base64-encoded image data
+        if response.text.startswith("data:image/tiff;base64,"):
+            img_data_out = base64.b64decode(response.text.split(",")[1])
+            predicted_image = np.array(Image.open(BytesIO(img_data_out)))
+            return predicted_image
+
+def select_random_image(dataset):
+    idx = random.randint(0, len(dataset) - 1)  # randint is inclusive on both ends
+    return idx, np.array(dataset[idx]["seismic"])
+
+def select_random_models():
+    model_a = random.choice(list(ModelEnum))
+    model_b = random.choice(list(ModelEnum))
+    return model_a, model_b
+
+
+# Create a Gradio interface
+with gr.Blocks() as evaluation:
+    gr.Markdown("""
+    ## Seismic Fault Detection Model Evaluation
+    This application allows you to compare the performance of different seismic fault detection models.
+    Two models are selected randomly, and their predictions are displayed side by side.
+    You can choose the better model or mark it as a tie. The results are recorded and used to update the model ratings.
+    """)
+
+    battle = gr.State([])
+    radio = gr.Radio(choices=["Less than 5 years", "5 to 20 years", "more than 20 years"], label="How much experience do you have in seismic fault interpretation?")
+    with gr.Row():
+        output_img1 = gr.Image(label="Model A Image")
+        output_img2 = gr.Image(label="Model B Image")
+
+    def show_images():
+        dataset = load_dataset("porestar/crossdomainfoundationmodeladaption-deepfault", split="valid")
+        idx, image_1 = select_random_image(dataset)
+        model_a, model_b = select_random_models()
+        fault_probability_1 = call_endpoint(model_a, image_1)
+        fault_probability_2 = call_endpoint(model_b, image_1)
+
+        img_1 = make_plot(image_1, fault_probability_1)
+        img_2 = make_plot(image_1, fault_probability_2)
+        experience = 1
+        if radio.value == "5 to 20 years":
+            experience = 2
+        elif radio.value == "more than 20 years":
+            experience = 3
+        battle.value.append(Battle(model_a=model_a, model_b=model_b, winner="tie", judge="None", experience=experience, image_idx=idx))
+        return img_1, img_2
+
+    # Define the function to make an API call
+    def make_api_call(choice: Winner):
+        api_url = "https://lukasmosser--seisbase-eval-add-battle.modal.run"
+        battle_out = battle.value
+        battle_out[-1].winner = choice
+        experience = 1
+        if radio.value == "5 to 20 years":
+            experience = 2
+        elif radio.value == "more than 20 years":
+            experience = 3
+        battle_out[-1].experience = experience
+        response = requests.post(api_url, json=battle_out[-1].dict())
+
+    # Load images on startup
+    evaluation.load(show_images, inputs=[], outputs=[output_img1, output_img2])
+
+    with gr.Row():
+        btn_winner_a = gr.Button("Winner Model A")
+        btn_tie = gr.Button("Tie")
+        btn_winner_b = gr.Button("Winner Model B")
+
+    # Define button click events
+    btn_winner_a.click(lambda: make_api_call(Winner.model_a), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
+    btn_tie.click(lambda: make_api_call(Winner.tie), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
+    btn_winner_b.click(lambda: make_api_call(Winner.model_b), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
+
+with gr.Blocks() as leaderboard:
+    def get_results():
+        response = requests.get("https://lukasmosser--seisbase-eval-compute-ratings.modal.run")
+        data = response.json()
+
+        models = [entry["model"] for entry in data]
+        elo_ratings = [entry["elo_rating"] for entry in data]
+
+        fig, ax = plt.subplots()
+        ax.barh(models, elo_ratings, color='skyblue')
+        ax.set_xlabel('ELO Rating')
+        ax.set_title('Model ELO Ratings')
+        plt.tight_layout()
+
+        fig.canvas.draw()
+
+        # Create a bytes buffer to save the plot
+        buf = io.BytesIO()
+        plt.savefig(buf, format='png', bbox_inches='tight')
+        buf.seek(0)
+
+        # Open the PNG image from the buffer and convert it to a NumPy array
+        image = np.array(Image.open(buf))
+        return image
+
+    with gr.Row():
+        elo_ratings = gr.Image(label="ELO Ratings")
+
+    leaderboard.load(get_results, inputs=[], outputs=[elo_ratings])
+
+demo = gr.TabbedInterface([evaluation, leaderboard], ["Arena", "Leaderboard"])
+
+# Launch the interface
+if __name__ == "__main__":
+    demo.launch(show_error=True)
+
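For reference, the prediction endpoint that call_endpoint() posts to can be exercised on its own. The sketch below is not part of the commit; it assumes the same {"img": ..., "model": ...} payload and the data:image/tiff;base64 response format seen in the diff, and the 128x128 zero array is only a hypothetical placeholder for a real seismic section.

# Hedged sketch: call the Modal prediction endpoint directly, assuming the
# payload and response format used by call_endpoint() in app.py above.
import base64
from io import BytesIO

import numpy as np
import requests
from PIL import Image

url = "https://lukasmosser--seisbase-endpoints-predict.modal.run"
dummy_seismic = np.zeros((128, 128), dtype=np.uint8)  # placeholder input, not real data
payload = {"img": dummy_seismic.tolist(), "model": "porestar/deepfault-unet-baseline-1"}

response = requests.post(url, json=payload)
if response.ok and response.text.startswith("data:image/tiff;base64,"):
    tiff_bytes = base64.b64decode(response.text.split(",")[1])
    fault_probability = np.array(Image.open(BytesIO(tiff_bytes)))
    print(fault_probability.shape)
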
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+matplotlib
+numpy
+gradio
+datasets
+requests
score_db.py
ADDED
@@ -0,0 +1,143 @@
+import csv
+import io
+import json
+import os
+from datetime import datetime
+from enum import Enum
+from pathlib import Path
+from typing import List
+
+import pandas as pd
+from fastapi import Response
+from modal import web_endpoint
+import modal
+from pydantic import BaseModel
+
+from rating import compute_mle_elo
+
+# -----------------------
+# Data Model Definition
+# -----------------------
+class ExperienceEnum(int, Enum):
+    novice = 1
+    intermediate = 2
+    expert = 3
+
+class Winner(str, Enum):
+    model_a = "model_a"
+    model_b = "model_b"
+    tie = "tie"
+
+class Model(str, Enum):
+    porestar_deepfault_unet_baseline_1 = "porestar/deepfault-unet-baseline-1"
+    porestar_deepfault_unet_baseline_2 = "porestar/deepfault-unet-baseline-2"
+
+class Battle(BaseModel):
+    model_a: Model
+    model_b: Model
+    winner: Winner
+    judge: str
+    image_idx: int
+    experience: ExperienceEnum = ExperienceEnum.novice
+    tstamp: str = str(datetime.now())
+
+class EloRating(BaseModel):
+    model: Model
+    elo_rating: float
+
+# -----------------------
+# Modal Configuration
+# -----------------------
+
+# Create a volume to persist data
+data_volume = modal.Volume.from_name("seisbase-data", create_if_missing=True)
+
+JSON_FILE_PATH = Path("/data/battles.json")
+RESULTS_FILE_PATH = Path("/data/ratings.csv")
+
+app_image = modal.Image.debian_slim(python_version="3.10").pip_install("pandas", "scikit-learn", "tqdm", "sympy")
+
+app = modal.App(
+    image=app_image,
+    name="seisbase-eval",
+    volumes={"/data": data_volume},
+)
+
+def ensure_json_file():
+    """Ensure the JSON file exists and is initialized with an empty array if necessary."""
+    if not os.path.exists(JSON_FILE_PATH):
+        JSON_FILE_PATH.parent.mkdir(parents=True, exist_ok=True)
+        with open(JSON_FILE_PATH, "w") as f:
+            json.dump([], f)
+
+def append_to_json_file(data):
+    """Append data to the JSON file."""
+    ensure_json_file()
+    try:
+        with open(JSON_FILE_PATH, "r+") as f:
+            try:
+                battles = json.load(f)
+            except json.JSONDecodeError:
+                # Reset the file if corrupted
+                battles = []
+            battles.append(data)
+            f.seek(0)
+            json.dump(battles, f, indent=4)
+            f.truncate()
+    except Exception as e:
+        raise RuntimeError(f"Failed to append data to JSON file: {e}")
+
+def read_json_file():
+    """Read data from the JSON file."""
+    ensure_json_file()
+    try:
+        with open(JSON_FILE_PATH, "r") as f:
+            try:
+                return json.load(f)
+            except json.JSONDecodeError:
+                return []  # Return an empty list if the file is corrupted
+    except Exception as e:
+        raise RuntimeError(f"Failed to read JSON file: {e}")
+
+@app.function()
+@web_endpoint(method="POST", docs=True)
+def add_battle(battle: Battle):
+    """Add a new battle to the JSON file."""
+    append_to_json_file(battle.dict())
+    return {"status": "success", "battle": battle.dict()}
+
+
+@app.function()
+@web_endpoint(method="GET", docs=True)
+def export_csv():
+    """Fetch all battles and return as CSV."""
+    battles = read_json_file()
+
+    # Create CSV in memory
+    output = io.StringIO()
+    writer = csv.DictWriter(output, fieldnames=["model_a", "model_b", "winner", "judge", "image_idx", "experience", "tstamp"])
+    writer.writeheader()
+    writer.writerows(battles)
+
+    csv_data = output.getvalue()
+    return Response(content=csv_data, media_type="text/csv")
+
+@app.function()
+@web_endpoint(method="GET", docs=True)
+def compute_ratings() -> List[EloRating]:
+    """Compute ratings from battles."""
+    battles = pd.read_json(JSON_FILE_PATH, dtype=[str, str, str, str, int, int, str]).sort_values(ascending=True, by=["tstamp"]).reset_index(drop=True)
+    elo_mle_ratings = compute_mle_elo(battles)
+    elo_mle_ratings.to_csv(RESULTS_FILE_PATH)
+
+    df = pd.read_csv(RESULTS_FILE_PATH)
+    df.columns = ["Model", "Elo rating"]
+    df = df.sort_values("Elo rating", ascending=False).reset_index(drop=True)
+    scores = []
+    for i in range(len(df)):
+        scores.append(EloRating(model=df["Model"][i], elo_rating=df["Elo rating"][i]))
+    return scores
+
+@app.local_entrypoint()
+def main():
+    print("Local entrypoint running. Check endpoints for functionality.")
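As a usage note, the add_battle endpoint consumed by the Gradio app expects a JSON body matching the Battle model above. The snippet below is a hedged sketch against the URL referenced in app.py; the field values are illustrative only and assume the deployed app validates the same schema.

# Hedged sketch: record one battle directly, assuming the deployed Modal app
# exposes the Battle schema defined above at this URL (taken from app.py).
import requests

battle = {
    "model_a": "porestar/deepfault-unet-baseline-1",
    "model_b": "porestar/deepfault-unet-baseline-2",
    "winner": "model_a",          # one of: model_a, model_b, tie
    "judge": "None",
    "image_idx": 0,               # index of the seismic section that was shown
    "experience": 2,              # ExperienceEnum: 1=novice, 2=intermediate, 3=expert
    "tstamp": "2024-01-01 00:00:00",
}
response = requests.post("https://lukasmosser--seisbase-eval-add-battle.modal.run", json=battle)
print(response.status_code, response.json())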