Spaces:
Build error
Build error
Update app.py, score_db.py, and requirements.txt
Browse files- app.py +144 -4
- requirements.txt +5 -0
- score_db.py +143 -0
app.py
CHANGED
|
@@ -1,7 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
|
| 7 |
-
demo.launch()
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import io
|
| 3 |
+
import random
|
| 4 |
+
from io import BytesIO
|
| 5 |
+
|
| 6 |
+
import matplotlib
|
| 7 |
+
matplotlib.use('Agg')
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import numpy as np
|
| 10 |
+
from PIL import Image
|
| 11 |
+
import requests
|
| 12 |
+
from datasets import load_dataset
|
| 13 |
import gradio as gr
|
| 14 |
|
| 15 |
+
from score_db import Battle
|
| 16 |
+
from score_db import Model as ModelEnum, Winner
|
| 17 |
+
|
| 18 |
+
def make_plot(seismic, predicted_image):
|
| 19 |
+
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
|
| 20 |
+
ax.imshow(Image.fromarray(seismic), cmap="gray")
|
| 21 |
+
ax.imshow(predicted_image, cmap="Reds", alpha=0.5, vmin=0, vmax=1)
|
| 22 |
+
ax.set_axis_off()
|
| 23 |
+
fig.canvas.draw()
|
| 24 |
+
|
| 25 |
+
# Create a bytes buffer to save the plot
|
| 26 |
+
buf = io.BytesIO()
|
| 27 |
+
plt.savefig(buf, format='png', bbox_inches='tight')
|
| 28 |
+
buf.seek(0)
|
| 29 |
+
|
| 30 |
+
# Open the PNG image from the buffer and convert it to a NumPy array
|
| 31 |
+
image = np.array(Image.open(buf))
|
| 32 |
+
return image
|
| 33 |
+
|
| 34 |
+
def call_endpoint(model: ModelEnum, img_array, url: str="https://lukasmosser--seisbase-endpoints-predict.modal.run"):
|
| 35 |
+
response = requests.post(url, json={"img": img_array.tolist(), "model": model})
|
| 36 |
+
|
| 37 |
+
if response:
|
| 38 |
+
# Parse the base64-encoded image data
|
| 39 |
+
if response.text.startswith("data:image/tiff;base64,"):
|
| 40 |
+
img_data_out = base64.b64decode(response.text.split(",")[1])
|
| 41 |
+
predicted_image = np.array(Image.open(BytesIO(img_data_out)))
|
| 42 |
+
return predicted_image
|
| 43 |
+
|
| 44 |
+
def select_random_image(dataset):
|
| 45 |
+
idx = random.randint(0, len(dataset))
|
| 46 |
+
return idx, np.array(dataset[idx]["seismic"])
|
| 47 |
+
|
| 48 |
+
def select_random_models():
|
| 49 |
+
model_a = random.choice(list(ModelEnum))
|
| 50 |
+
model_b = random.choice(list(ModelEnum))
|
| 51 |
+
return model_a, model_b
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Create a Gradio interface
|
| 55 |
+
with gr.Blocks() as evaluation:
|
| 56 |
+
gr.Markdown("""
|
| 57 |
+
## Seismic Fault Detection Model Evaluation
|
| 58 |
+
This application allows you to compare the performance of different seismic fault detection models.
|
| 59 |
+
Two models are selected randomly, and their predictions are displayed side by side.
|
| 60 |
+
You can choose the better model or mark it as a tie. The results are recorded and used to update the model ratings.
|
| 61 |
+
""")
|
| 62 |
+
|
| 63 |
+
battle = gr.State([])
|
| 64 |
+
radio = gr.Radio(choices=["Less than 5 years", "5 to 20 years", "more than 20 years"], label="How much experience do you have in seismic fault interpretation?")
|
| 65 |
+
with gr.Row():
|
| 66 |
+
output_img1 = gr.Image(label="Model A Image")
|
| 67 |
+
output_img2 = gr.Image(label="Model B Image")
|
| 68 |
+
|
| 69 |
+
def show_images():
|
| 70 |
+
dataset = load_dataset("porestar/crossdomainfoundationmodeladaption-deepfault", split="valid")
|
| 71 |
+
idx, image_1 = select_random_image(dataset)
|
| 72 |
+
model_a, model_b = select_random_models()
|
| 73 |
+
fault_probability_1 = call_endpoint(model_a, image_1)
|
| 74 |
+
fault_probability_2 = call_endpoint(model_b, image_1)
|
| 75 |
+
|
| 76 |
+
img_1 = make_plot(image_1, fault_probability_1)
|
| 77 |
+
img_2 = make_plot(image_1, fault_probability_2)
|
| 78 |
+
experience = 1
|
| 79 |
+
if radio.value == "5 to 20 years":
|
| 80 |
+
experience = 2
|
| 81 |
+
elif radio.value == "more than 20 years":
|
| 82 |
+
experience = 3
|
| 83 |
+
battle.value.append(Battle(model_a=model_a, model_b=model_b, winner="tie", judge="None", experience=experience, image_idx=idx))
|
| 84 |
+
return img_1, img_2
|
| 85 |
+
|
| 86 |
+
# Define the function to make an API call
|
| 87 |
+
def make_api_call(choice: Winner):
|
| 88 |
+
api_url = "https://lukasmosser--seisbase-eval-add-battle.modal.run"
|
| 89 |
+
battle_out = battle.value
|
| 90 |
+
battle_out[-1].winner = choice
|
| 91 |
+
experience = 1
|
| 92 |
+
if radio.value == "5 to 20 years":
|
| 93 |
+
experience = 2
|
| 94 |
+
elif radio.value == "more than 20 years":
|
| 95 |
+
experience = 3
|
| 96 |
+
battle_out[-1].experience = experience
|
| 97 |
+
response = requests.post(api_url, json=battle_out[-1].dict())
|
| 98 |
+
|
| 99 |
+
# Load images on startup
|
| 100 |
+
evaluation.load(show_images, inputs=[], outputs=[output_img1, output_img2])
|
| 101 |
+
|
| 102 |
+
with gr.Row():
|
| 103 |
+
btn_winner_a = gr.Button("Winner Model A")
|
| 104 |
+
btn_tie = gr.Button("Tie")
|
| 105 |
+
btn_winner_b = gr.Button("Winner Model B")
|
| 106 |
+
|
| 107 |
+
# Define button click events
|
| 108 |
+
btn_winner_a.click(lambda: make_api_call(Winner.model_a), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
|
| 109 |
+
btn_tie.click(lambda: make_api_call(Winner.tie), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
|
| 110 |
+
btn_winner_b.click(lambda: make_api_call(Winner.model_b), inputs=[], outputs=[]).then(show_images, inputs=[], outputs=[output_img1, output_img2])
|
| 111 |
+
|
| 112 |
+
with gr.Blocks() as leaderboard:
|
| 113 |
+
def get_results():
|
| 114 |
+
response = requests.get("https://lukasmosser--seisbase-eval-compute-ratings.modal.run")
|
| 115 |
+
data = response.json()
|
| 116 |
+
|
| 117 |
+
models = [entry["model"] for entry in data]
|
| 118 |
+
elo_ratings = [entry["elo_rating"] for entry in data]
|
| 119 |
+
|
| 120 |
+
fig, ax = plt.subplots()
|
| 121 |
+
ax.barh(models, elo_ratings, color='skyblue')
|
| 122 |
+
ax.set_xlabel('ELO Rating')
|
| 123 |
+
ax.set_title('Model ELO Ratings')
|
| 124 |
+
plt.tight_layout()
|
| 125 |
+
|
| 126 |
+
fig.canvas.draw()
|
| 127 |
+
|
| 128 |
+
# Create a bytes buffer to save the plot
|
| 129 |
+
buf = io.BytesIO()
|
| 130 |
+
plt.savefig(buf, format='png', bbox_inches='tight')
|
| 131 |
+
buf.seek(0)
|
| 132 |
+
|
| 133 |
+
# Open the PNG image from the buffer and convert it to a NumPy array
|
| 134 |
+
image = np.array(Image.open(buf))
|
| 135 |
+
return image
|
| 136 |
+
|
| 137 |
+
with gr.Row():
|
| 138 |
+
elo_ratings = gr.Image(label="ELO Ratings")
|
| 139 |
+
|
| 140 |
+
leaderboard.load(get_results, inputs=[], outputs=[elo_ratings])
|
| 141 |
+
|
| 142 |
+
demo = gr.TabbedInterface([evaluation, leaderboard], ["Arena", "Leaderboard"])
|
| 143 |
+
|
| 144 |
+
# Launch the interface
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
demo.launch(show_error=True)
|
| 147 |
|
|
|
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
matplotlib
|
| 2 |
+
numpy
|
| 3 |
+
gradio
|
| 4 |
+
datasets
|
| 5 |
+
requests
|
score_db.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import io
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import List
|
| 9 |
+
|
| 10 |
+
import pandas as pd
|
| 11 |
+
from fastapi import Response
|
| 12 |
+
from modal import web_endpoint
|
| 13 |
+
import modal
|
| 14 |
+
from pydantic import BaseModel
|
| 15 |
+
|
| 16 |
+
from rating import compute_mle_elo
|
| 17 |
+
|
| 18 |
+
# -----------------------
|
| 19 |
+
# Data Model Definition
|
| 20 |
+
# -----------------------
|
| 21 |
+
class ExperienceEnum(int, Enum):
|
| 22 |
+
novice = 1
|
| 23 |
+
intermediate = 2
|
| 24 |
+
expert = 3
|
| 25 |
+
|
| 26 |
+
class Winner(str, Enum):
|
| 27 |
+
model_a = "model_a"
|
| 28 |
+
model_b = "model_b"
|
| 29 |
+
tie = "tie"
|
| 30 |
+
|
| 31 |
+
class Model(str, Enum):
|
| 32 |
+
porestar_deepfault_unet_baseline_1 = "porestar/deepfault-unet-baseline-1"
|
| 33 |
+
porestar_deepfault_unet_baseline_2 = "porestar/deepfault-unet-baseline-2"
|
| 34 |
+
|
| 35 |
+
class Battle(BaseModel):
|
| 36 |
+
model_a: Model
|
| 37 |
+
model_b: Model
|
| 38 |
+
winner: Winner
|
| 39 |
+
judge: str
|
| 40 |
+
image_idx: int
|
| 41 |
+
experience: ExperienceEnum = ExperienceEnum.novice
|
| 42 |
+
tstamp: str = str(datetime.now())
|
| 43 |
+
|
| 44 |
+
class EloRating(BaseModel):
|
| 45 |
+
model: Model
|
| 46 |
+
elo_rating: float
|
| 47 |
+
|
| 48 |
+
# -----------------------
|
| 49 |
+
# Modal Configuration
|
| 50 |
+
# -----------------------
|
| 51 |
+
|
| 52 |
+
# Create a volume to persist data
|
| 53 |
+
data_volume = modal.Volume.from_name("seisbase-data", create_if_missing=True)
|
| 54 |
+
|
| 55 |
+
JSON_FILE_PATH = Path("/data/battles.json")
|
| 56 |
+
RESULTS_FILE_PATH = Path("/data/ratings.csv")
|
| 57 |
+
|
| 58 |
+
app_image = modal.Image.debian_slim(python_version="3.10").pip_install("pandas", "scikit-learn", "tqdm", "sympy")
|
| 59 |
+
|
| 60 |
+
app = modal.App(
|
| 61 |
+
image=app_image,
|
| 62 |
+
name="seisbase-eval",
|
| 63 |
+
volumes={"/data": data_volume},
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
def ensure_json_file():
|
| 67 |
+
"""Ensure the JSON file exists and is initialized with an empty array if necessary."""
|
| 68 |
+
if not os.path.exists(JSON_FILE_PATH):
|
| 69 |
+
JSON_FILE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
| 70 |
+
with open(JSON_FILE_PATH, "w") as f:
|
| 71 |
+
json.dump([], f)
|
| 72 |
+
|
| 73 |
+
def append_to_json_file(data):
|
| 74 |
+
"""Append data to the JSON file."""
|
| 75 |
+
ensure_json_file()
|
| 76 |
+
try:
|
| 77 |
+
with open(JSON_FILE_PATH, "r+") as f:
|
| 78 |
+
try:
|
| 79 |
+
battles = json.load(f)
|
| 80 |
+
except json.JSONDecodeError:
|
| 81 |
+
# Reset the file if corrupted
|
| 82 |
+
battles = []
|
| 83 |
+
battles.append(data)
|
| 84 |
+
f.seek(0)
|
| 85 |
+
json.dump(battles, f, indent=4)
|
| 86 |
+
f.truncate()
|
| 87 |
+
except Exception as e:
|
| 88 |
+
raise RuntimeError(f"Failed to append data to JSON file: {e}")
|
| 89 |
+
|
| 90 |
+
def read_json_file():
|
| 91 |
+
"""Read data from the JSON file."""
|
| 92 |
+
ensure_json_file()
|
| 93 |
+
try:
|
| 94 |
+
with open(JSON_FILE_PATH, "r") as f:
|
| 95 |
+
try:
|
| 96 |
+
return json.load(f)
|
| 97 |
+
except json.JSONDecodeError:
|
| 98 |
+
return [] # Return an empty list if the file is corrupted
|
| 99 |
+
except Exception as e:
|
| 100 |
+
raise RuntimeError(f"Failed to read JSON file: {e}")
|
| 101 |
+
|
| 102 |
+
@app.function()
|
| 103 |
+
@web_endpoint(method="POST", docs=True)
|
| 104 |
+
def add_battle(battle: Battle):
|
| 105 |
+
"""Add a new battle to the JSON file."""
|
| 106 |
+
append_to_json_file(battle.dict())
|
| 107 |
+
return {"status": "success", "battle": battle.dict()}
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@app.function()
|
| 111 |
+
@web_endpoint(method="GET", docs=True)
|
| 112 |
+
def export_csv():
|
| 113 |
+
"""Fetch all battles and return as CSV."""
|
| 114 |
+
battles = read_json_file()
|
| 115 |
+
|
| 116 |
+
# Create CSV in memory
|
| 117 |
+
output = io.StringIO()
|
| 118 |
+
writer = csv.DictWriter(output, fieldnames=["model_a", "model_b", "winner", "judge", "imaged_idx", "experience", "tstamp"])
|
| 119 |
+
writer.writeheader()
|
| 120 |
+
writer.writerows(battles)
|
| 121 |
+
|
| 122 |
+
csv_data = output.getvalue()
|
| 123 |
+
return Response(content=csv_data, media_type="text/csv")
|
| 124 |
+
|
| 125 |
+
@app.function()
|
| 126 |
+
@web_endpoint(method="GET", docs=True)
|
| 127 |
+
def compute_ratings() -> List[EloRating]:
|
| 128 |
+
"""Compute ratings from battles."""
|
| 129 |
+
battles = pd.read_json(JSON_FILE_PATH, dtype=[str, str, str, str, int, int, str]).sort_values(ascending=True, by=["tstamp"]).reset_index(drop=True)
|
| 130 |
+
elo_mle_ratings = compute_mle_elo(battles)
|
| 131 |
+
elo_mle_ratings.to_csv(RESULTS_FILE_PATH)
|
| 132 |
+
|
| 133 |
+
df = pd.read_csv(RESULTS_FILE_PATH)
|
| 134 |
+
df.columns = ["Model", "Elo rating"]
|
| 135 |
+
df = df.sort_values("Elo rating", ascending=False).reset_index(drop=True)
|
| 136 |
+
scores = []
|
| 137 |
+
for i in range(len(df)):
|
| 138 |
+
scores.append(EloRating(model=df["Model"][i], elo_rating=df["Elo rating"][i]))
|
| 139 |
+
return scores
|
| 140 |
+
|
| 141 |
+
@app.local_entrypoint()
|
| 142 |
+
def main():
|
| 143 |
+
print("Local entrypoint running. Check endpoints for functionality.")
|