# MOSAIC / app.py
# Last commit: fdf4266 "Moved some wrappers" by BaggerOfWords (file size 10.1 kB)
import gradio as gr
from mosaic import Mosaic # adjust import as needed
import spaces
import traceback
from transformers import AutoModelForCausalLM
import torch
# Upper bound on how many model slots (path textbox + Load button + status box)
# the UI builds; slots beyond the visible count exist but are hidden.
MAX_MODELS = 10

# Models that have already been loaded, reused across Load clicks and handed to
# Mosaic by run_scoring. Keys are the repo/path strings given to
# AutoModelForCausalLM.from_pretrained; values are the loaded models.
LOADED_MODELS = {}

# Preset used by the "Try Basic gpt Configuration" button.
GPT_CONFIG_MODELS = [
    "openai-community/gpt2-large",
    "openai-community/gpt2-medium",
    "openai-community/gpt2",
]

# Preset used by the "Try Falcon models Configuration" button.
# Presumably chosen so all entries share one tokenizer, as the UI requires.
Falcon_CONFIG_MODELS = [
    "tiiuae/Falcon3-10B-Base",
    "tiiuae/Falcon3-10B-Instruct",
    "tiiuae/Falcon3-7B-Base",
    "tiiuae/Falcon3-7B-Instruct",
]
# Handler for the "Add model slot" button.
def update_textboxes(n_visible):
    """Reveal one more model slot, never exceeding MAX_MODELS.

    Args:
        n_visible: number of model slots currently shown.

    Returns:
        ``(shown, *textbox_updates, *button_updates, *status_updates)`` where each
        group holds one ``gr.update`` per slot and slot ``k`` is visible iff
        ``k < shown``.
    """
    shown = n_visible + 1 if n_visible < MAX_MODELS else n_visible
    # Textboxes, Load buttons and status boxes all share the same visibility rule,
    # so the same column of updates is generated once per component group.
    textboxes, buttons, statuses = (
        [gr.update(visible=slot < shown) for slot in range(MAX_MODELS)]
        for _group in range(3)
    )
    return (shown, *textboxes, *buttons, *statuses)
# Handler for the "Remove model slot" button: decrease model slots and clear removed entries.
def remove_textboxes(n_visible, *model_paths):
    """Hide the last visible model slot (keeping at least two) and reset hidden slots.

    Args:
        n_visible: number of model slots currently shown.
        *model_paths: optional current values of all MAX_MODELS model textboxes, in
            slot order. LOADED_MODELS is keyed by repo path, not by slot index, so the
            model cached for the slot being hidden can only be evicted when these
            paths are supplied; it is kept if a still-visible slot uses the same path.
            When called with only ``n_visible``, nothing is evicted.

    Returns:
        ``(n_visible, *textbox_updates, *button_updates, *status_updates)``. Visible
        slots are shown; every hidden slot has its path cleared and its status
        reset to "Not loaded".
    """
    old = n_visible
    if n_visible > 2:
        n_visible -= 1

    if model_paths:
        # Never evict a model that a remaining visible slot still points at.
        in_use = {(path or "").strip() for path in model_paths[:n_visible]}
        for raw_path in model_paths[n_visible:old]:
            repo = (raw_path or "").strip()
            if repo and repo not in in_use:
                LOADED_MODELS.pop(repo, None)

    tb_updates, btn_updates, status_updates = [], [], []
    for i in range(MAX_MODELS):
        if i < n_visible:
            tb_updates.append(gr.update(visible=True))
            btn_updates.append(gr.update(visible=True))
            status_updates.append(gr.update(visible=True))
        else:
            tb_updates.append(gr.update(visible=False, value=""))
            btn_updates.append(gr.update(visible=False))
            status_updates.append(gr.update(visible=False, value="Not loaded"))
    return (n_visible, *tb_updates, *btn_updates, *status_updates)
def _apply_model_config(config_models):
    """Place ``config_models`` into the first slots and hide/clear all other slots.

    Args:
        config_models: sequence of model repo paths for the preset. Entries beyond
            MAX_MODELS are ignored because there is no slot to show them in.

    Returns:
        ``(n_visible, *textbox_updates, *button_updates, *status_updates)`` -- the
        same shape as update_textboxes/remove_textboxes, i.e. the order expected by
        the slot-management outputs. Every status is reset to "Not loaded", so each
        preset model must be loaded (via its Load button) before Run is enabled.
    """
    presets = list(config_models)[:MAX_MODELS]
    n_vis = len(presets)
    tb_updates, btn_updates, status_updates = [], [], []
    for i in range(MAX_MODELS):
        shown = i < n_vis
        # Hidden slots are emptied so run_scoring does not pick up stale paths.
        tb_updates.append(gr.update(visible=shown, value=presets[i] if shown else ""))
        btn_updates.append(gr.update(visible=shown))
        status_updates.append(gr.update(visible=shown, value="Not loaded"))
    return (n_vis, *tb_updates, *btn_updates, *status_updates)


def apply_config1():
    """Fill the model slots with the GPT-2 preset (GPT_CONFIG_MODELS).

    Returns the slot-count/textbox/button/status update tuple described in
    ``_apply_model_config``.
    """
    return _apply_model_config(GPT_CONFIG_MODELS)


def apply_config2():
    """Fill the model slots with the Falcon3 preset (Falcon_CONFIG_MODELS).

    Returns the slot-count/textbox/button/status update tuple described in
    ``_apply_model_config``.
    """
    return _apply_model_config(Falcon_CONFIG_MODELS)
# Load a single model and report status
@spaces.GPU()
def load_single_model(model_path, use_bfloat16=True):
    """Load ``model_path`` into LOADED_MODELS and return a status string for the UI.

    Args:
        model_path: repo id or local path typed into a model slot. Surrounding
            whitespace is ignored so the cache key matches the stripped paths that
            run_scoring hands to Mosaic; ``None`` is treated as empty.
        use_bfloat16: load weights as ``torch.bfloat16`` when truthy, otherwise
            ``torch.float32``.

    Returns:
        "Loaded" on success or if the model is already cached,
        "Error: No path provided" for an empty path, or
        "Error loading model: <exception>" if anything fails. check_all_loaded only
        treats the exact string "Loaded" as success. This function never raises,
        because its return value is shown directly in a status box.
    """
    try:
        repo = (model_path or "").strip()
        if not repo:
            return "Error: No path provided"
        if repo in LOADED_MODELS:
            return "Loaded"
        # Actual load; may raise (bad repo id, network, OOM, ...).
        model = AutoModelForCausalLM.from_pretrained(
            repo,
            device_map="auto",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if use_bfloat16 else torch.float32,
        )
        model.eval()
        LOADED_MODELS[repo] = model
        return "Loaded"
    except Exception as e:  # UI boundary: report the failure in the status box
        return f"Error loading model: {e}"
# Decide whether the Run button may be clicked.
def check_all_loaded(n_visible, *status_texts):
    """Return a Run-button update that is interactive only when every visible slot is loaded.

    Args:
        n_visible: number of model slots currently shown.
        *status_texts: status-box strings for all slots, in slot order. Only the
            first ``n_visible`` are inspected, and a slot counts as ready only when
            its status is exactly "Loaded".
    """
    visible_statuses = status_texts[:n_visible]
    ready = not any(text != "Loaded" for text in visible_statuses)
    return gr.update(interactive=ready)
def run_scoring(input_text, *args):
    """Score ``input_text`` with MOSAIC using the model paths entered in the slots.

    Args:
        input_text: text to classify.
        *args: the MAX_MODELS model-path textbox values (empty or ``None`` slots are
            skipped), followed by ``threshold_choice`` ("default" or "custom") and
            ``custom_threshold`` (used only when the choice is not "default").

    Returns:
        ``(message, final_score, threshold)``. A score below the threshold is
        reported as probably generated, otherwise as likely human-written. On
        invalid input the message explains what is missing; on any unexpected
        failure it contains the error and traceback. In both cases the last two
        values are ``None``.
    """
    try:
        # Unpack. Gradio may deliver None for a textbox that was never edited.
        stripped_paths = ((raw or "").strip() for raw in args[:MAX_MODELS])
        models = [path for path in stripped_paths if path]
        threshold_choice = args[MAX_MODELS]
        custom_threshold = args[MAX_MODELS + 1]

        if len(models) < 2:
            return "Please enter at least two model paths.", None, None
        if not input_text or not input_text.strip():
            return "Please enter some text to score.", None, None

        if threshold_choice == "default":
            threshold = 0.0
        elif custom_threshold is None:
            # Check before the expensive Mosaic run, instead of failing on
            # `final_score < None` afterwards.
            return "Please enter a custom threshold value.", None, None
        else:
            threshold = custom_threshold

        mosaic_instance = Mosaic(model_name_or_paths=models, one_model_mode=False, loaded_models=LOADED_MODELS)
        final_score = mosaic_instance.compute_end_score(input_text)
        msg = "This text was probably generated." if final_score < threshold else "This text is likely human-written."
        return msg, final_score, threshold
    except Exception as e:  # UI boundary: surface the failure in the result box
        tb = traceback.format_exc()
        return f"Error: {e}\n{tb}", None, None
# Build Blocks UI
with gr.Blocks() as demo:
    gr.Markdown("# MOSAIC Scoring App")
    with gr.Row():
        input_text = gr.Textbox(lines=10, placeholder="Enter text here...", label="Input Text")
    with gr.Column():
        gr.Markdown("**⚠️ Please make sure all models have the same tokenizer or it won’t work.**")
        gr.Markdown("### Model Paths (at least 2 required)")
        # Number of currently visible model slots.
        n_models_state = gr.State(4)
        model_inputs, load_buttons, status_boxes = [], [], []
        for i in range(1, MAX_MODELS + 1):
            with gr.Row():
                # Every slot starts as "" (not None) so handlers can call .strip() safely.
                tb = gr.Textbox(label=f"Model {i} Path", value="", visible=(i <= 4))
                btn = gr.Button("Load", elem_id=f"load_{i}", visible=(i <= 4))
                status = gr.Textbox(label="Loading status", value="Not loaded", interactive=False, visible=(i <= 4))
            # Only the path is passed. load_single_model's second parameter is
            # use_bfloat16, so passing the slot index there misused it as the dtype flag.
            btn.click(fn=load_single_model, inputs=tb, outputs=status)
            # A user edit to the path invalidates any earlier "Loaded" status for this
            # slot; otherwise Run could be enabled for a model that was never loaded.
            tb.input(fn=lambda: "Not loaded", inputs=None, outputs=status)
            model_inputs.append(tb)
            load_buttons.append(btn)
            status_boxes.append(status)
        with gr.Row():
            plus = gr.Button("Add model slot", elem_id="plus_button")
            minus = gr.Button("Remove model slot", elem_id="minus_button")
            config1_btn = gr.Button("Try Basic gpt Configuration")
            config2_btn = gr.Button("Try Falcon models Configuration")
        # All slot-management handlers return the same shape:
        # (new visible-slot count, *textbox updates, *Load-button updates, *status updates).
        slot_outputs = [n_models_state, *model_inputs, *load_buttons, *status_boxes]
        plus.click(fn=update_textboxes, inputs=n_models_state, outputs=slot_outputs)
        minus.click(fn=remove_textboxes, inputs=n_models_state, outputs=slot_outputs)
        config1_btn.click(fn=apply_config1, inputs=None, outputs=slot_outputs)
        config2_btn.click(fn=apply_config2, inputs=None, outputs=slot_outputs)
    with gr.Row():
        threshold_choice = gr.Radio(choices=["default", "custom"], value="default", label="Threshold Choice")
        custom_threshold = gr.Number(value=0.0, label="Custom Threshold (if 'custom' selected)")
    with gr.Row():
        output_message = gr.Textbox(label="Result Message")
        output_score = gr.Number(label="Final Score")
        output_threshold = gr.Number(label="Threshold Used")
    gr.Markdown("**⚠️ All models need to be loaded before scoring.**")
    run_button = gr.Button("Run Scoring", interactive=False)

    # Re-evaluate the Run button whenever any status or the visible-slot count changes,
    # so it is enabled only when every visible slot reports "Loaded".
    readiness_inputs = [n_models_state, *status_boxes]
    for status in status_boxes:
        status.change(fn=check_all_loaded, inputs=readiness_inputs, outputs=run_button)
    n_models_state.change(fn=check_all_loaded, inputs=readiness_inputs, outputs=run_button)

    run_button.click(
        fn=run_scoring,
        inputs=[input_text, *model_inputs, threshold_choice, custom_threshold],
        outputs=[output_message, output_score, output_threshold],
    )

# Launch
if __name__ == "__main__":
    demo.launch()