TheFrenchDemos's picture
implemented detection
7f97da4
raw
history blame
10.6 kB
"""
Main Flask application for the watermark detection web interface.
"""
from flask import Flask, render_template, request, jsonify, Response, stream_with_context
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
from ..core.detector import MarylandDetector, MarylandDetectorZ, OpenaiDetector, OpenaiDetectorZ
from ..core.generator import WmGenerator, OpenaiGenerator, MarylandGenerator
from .utils import get_token_details, template_prompt
# Hugging Face download cache for the tokenizer and model weights; passed as
# ``cache_dir`` to both ``from_pretrained`` calls in create_app(). This is a
# relative path, so it presumably resolves against the process's working
# directory — NOTE(review): confirm the server is launched from the repo root.
CACHE_DIR: str = "wm_interactive/static/hf_cache"
def convert_nan_to_null(obj):
    """Recursively replace non-finite floats with ``None`` so *obj* is JSON-safe.

    By default ``json.dumps`` (and therefore Flask's ``jsonify``) writes
    ``NaN``, ``Infinity`` and ``-Infinity`` for non-finite floats. Those
    tokens are not valid JSON, and a strict parser such as the browser's
    ``JSON.parse`` rejects them. Both NaN and +/-inf are therefore mapped
    to ``None`` (serialized as ``null``).

    Args:
        obj: Any value, possibly nested inside dicts, lists or tuples.

    Returns:
        The same structure with every NaN/+inf/-inf float replaced by
        ``None``. Dicts are rebuilt with the same keys, and lists and tuples
        come back as lists (JSON has no tuple type). All other values are
        returned unchanged.
    """
    import math

    if isinstance(obj, float):
        return obj if math.isfinite(obj) else None
    if isinstance(obj, dict):
        return {k: convert_nan_to_null(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_nan_to_null(item) for item in obj]
    return obj
def set_to_int(value, default_value=None):
    """Coerce *value* to ``int``, returning *default_value* if that fails.

    Intended for client-supplied parameters, which may be missing, of the
    wrong type or malformed.

    Args:
        value: Anything accepted by ``int()``, e.g. ``3``, ``"3"`` or ``3.9``
            (the last truncates to ``3``).
        default_value: Returned when *value* cannot be converted.

    Returns:
        ``int(value)``, or *default_value* when the conversion raises
        ``ValueError`` (e.g. ``"abc"`` or NaN), ``TypeError`` (e.g. ``None``)
        or ``OverflowError`` (+/-inf, which ``json.loads`` accepts as
        ``Infinity``).
    """
    try:
        return int(value)
    except (ValueError, TypeError, OverflowError):
        return default_value
def create_detector(detector_type, tokenizer, **kwargs):
    """Instantiate the watermark detector selected by *detector_type*.

    Args:
        detector_type: ``'maryland'``, ``'marylandz'``, ``'openai'`` or
            ``'openaiz'``. Any other (hashable) value falls back to
            ``MarylandDetector``.
        tokenizer: Forwarded unchanged to the detector constructor.
        **kwargs: Extra constructor arguments. If ``seed`` or ``ngram`` is
            present it is coerced with ``set_to_int``. A failed coercion
            gives ``0`` for ``seed`` and ``1`` for ``ngram``.

    Returns:
        A freshly constructed detector instance.
    """
    # Coerce the integer hyper-parameters, substituting safe fallbacks.
    for name, fallback in (('seed', 0), ('ngram', 1)):
        if name in kwargs:
            kwargs[name] = set_to_int(kwargs[name], default_value=fallback)

    detectors_by_name = {
        'maryland': MarylandDetector,
        'marylandz': MarylandDetectorZ,
        'openai': OpenaiDetector,
        'openaiz': OpenaiDetectorZ,
    }
    chosen_class = detectors_by_name.get(detector_type, MarylandDetector)
    return chosen_class(tokenizer=tokenizer, **kwargs)
# Neutral values reported by /tokenize for the summary statistics when the
# detector gives no value (None) or a non-finite one for them.
_SUMMARY_DEFAULTS = {'final_score': 0, 'ntoks_scored': 0, 'final_pvalue': 0.5}


def _build_token_response(display_info):
    """Convert ``get_token_details`` output into the ``/tokenize`` JSON payload.

    *display_info* is expected to hold per-token dicts followed by one
    trailing summary-stats dict.
    - Each per-token dict has ``'token'`` and ``'color'``.
    - It also has ``'score'``/``'pvalue'`` when ``'is_scored'`` is true.
    - The summary dict may contain ``'final_score'``, ``'ntoks_scored'``
      and ``'final_pvalue'``.

    Non-finite numbers become ``None``. Missing or non-finite summary stats
    fall back to ``_SUMMARY_DEFAULTS``.
    """
    *token_infos, stats = display_info
    response = convert_nan_to_null({
        'token_count': len(token_infos),
        'tokens': [info['token'] for info in token_infos],
        'colors': [info['color'] for info in token_infos],
        'scores': [info['score'] if info.get('is_scored', False) else None for info in token_infos],
        'pvalues': [info['pvalue'] if info.get('is_scored', False) else None for info in token_infos],
        'final_score': stats.get('final_score'),
        'ntoks_scored': stats.get('ntoks_scored'),
        'final_pvalue': stats.get('final_pvalue'),
    })
    for key, fallback in _SUMMARY_DEFAULTS.items():
        if response[key] is None:
            response[key] = fallback
    return response


def _token_id_set(value):
    """Normalize a token-id config value (None, int or sequence of ints) to a set of ints."""
    if value is None:
        return set()
    if isinstance(value, (list, tuple, set, frozenset)):
        return {int(v) for v in value}
    return {int(value)}


def create_app():
    """Build the Flask app for the watermark detection/generation demo.

    The tokenizer and model are loaded once, and the model is moved to CUDA
    when available. The app registers:

    * ``GET /``: renders ``index.html``.
    * ``POST /tokenize``: scores the posted text with the requested detector.
    * ``POST /generate``: streams watermarked text as server-sent events.

    Returns:
        The configured ``Flask`` application.
    """
    app = Flask(__name__,
                static_folder='../static',
                template_folder='../templates')
    # Expose Python's zip() to Jinja templates.
    app.jinja_env.globals.update(zip=zip)

    # Pick a model (a larger alternative that can be swapped in):
    # model_id = "meta-llama/Llama-3.2-1B-Instruct"
    model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=CACHE_DIR).to(device)

    # Token strings that are never shown to the user. The tokenizer is fixed
    # for the app's lifetime, so this set is built once rather than per request.
    special_tokens = {
        '<|im_start|>', '<|im_end|>',
        tokenizer.pad_token, tokenizer.eos_token,
        getattr(tokenizer, 'bos_token', None),
        getattr(tokenizer, 'sep_token', None),
    }
    special_tokens.discard(None)

    # Generation stops on any of these ids. model.config.eos_token_id may be
    # a single int or a list depending on the model, so normalize it.
    stop_ids = _token_id_set(model.config.eos_token_id) | _token_id_set(tokenizer.eos_token_id)
    max_gen_len = 100
    max_positions = getattr(model.config, 'max_position_embeddings', 2048)

    @app.route("/", methods=["GET"])
    def index():
        return render_template("index.html")

    @app.route("/tokenize", methods=["POST"])
    def tokenize():
        """Score the posted ``text`` with a detector and return per-token details."""
        try:
            # silent=True: a malformed or non-JSON body yields None and gets
            # the 400 below instead of raising and being reported as a 500.
            data = request.get_json(silent=True)
            if not data:
                return jsonify({'error': 'No JSON data received'}), 400

            text = data.get('text', '')
            params = data.get('params') or {}
            detector = create_detector(
                detector_type=params.get('detector_type', 'maryland'),
                tokenizer=tokenizer,
                seed=params.get('seed', 0),
                ngram=params.get('ngram', 1),
            )

            if not text:
                return jsonify({
                    'token_count': 0,
                    'tokens': [],
                    'colors': [],
                    'scores': [],
                    'pvalues': [],
                    **_SUMMARY_DEFAULTS,
                })

            try:
                return jsonify(_build_token_response(get_token_details(text, detector)))
            except Exception as e:
                app.logger.exception('Error processing text: %s', e)
                return jsonify({'error': f'Error processing text: {str(e)}'}), 500
        except Exception as e:
            app.logger.exception('Server error: %s', e)
            return jsonify({'error': f'Server error: {str(e)}'}), 500

    @app.route("/generate", methods=["POST"])
    def generate():
        """Stream watermarked text for the posted ``prompt`` as server-sent events.

        Each event is a ``data: <json>`` line:
        - ``{'token', 'done': False}`` for each displayable token;
        - then ``{'text', 'done': True}`` with the full cleaned text;
        - or ``{'error'}`` if generation fails.
        """
        try:
            data = request.get_json(silent=True)
            if not data:
                return jsonify({'error': 'No JSON data received'}), 400
            prompt = template_prompt(data.get('prompt', ''))
            params = data.get('params') or {}
            temperature = float(params.get('temperature', 0.8))

            def generate_stream():
                try:
                    # Both OpenAI detector variants ('openai' and its z-test
                    # 'openaiz') score OpenAI-style watermarks; everything else
                    # uses the Maryland scheme.
                    detector_type = params.get('detector_type')
                    generator_class = (OpenaiGenerator if detector_type in ('openai', 'openaiz')
                                       else MarylandGenerator)
                    # Same fallbacks as create_detector, so a bad value never
                    # becomes None (which would break the ngram slice below).
                    generator = generator_class(
                        model=model,
                        tokenizer=tokenizer,
                        ngram=set_to_int(params.get('ngram', 1), default_value=1),
                        seed=set_to_int(params.get('seed', 0), default_value=0),
                        delta=float(params.get('delta', 2.0)),
                    )

                    prompt_tokens = tokenizer.encode(prompt)
                    prompt_size = len(prompt_tokens)
                    total_len = min(max_positions, max_gen_len + prompt_size)
                    if prompt_size >= total_len:
                        raise ValueError(
                            f'Prompt is too long ({prompt_size} tokens) for the model '
                            f'context ({max_positions} tokens)'
                        )

                    # Only positions < cur_pos are ever read (model input and
                    # ngram context), and those always hold prompt or sampled
                    # ids, so the zero filler is never observed.
                    tokens = torch.zeros((1, total_len), dtype=torch.long, device=model.device)
                    tokens[0, :prompt_size] = torch.tensor(prompt_tokens, dtype=torch.long)

                    generated = []
                    outputs = None
                    prev_pos = 0
                    for cur_pos in range(prompt_size, total_len):
                        # First step feeds the whole prompt; afterwards only
                        # the newest token, reusing the KV cache.
                        outputs = model.forward(
                            tokens[:, prev_pos:cur_pos],
                            use_cache=True,
                            past_key_values=outputs.past_key_values if prev_pos > 0 else None,
                        )
                        aux = {
                            'ngram_tokens': tokens[0, cur_pos - generator.ngram:cur_pos].tolist(),
                            'cur_pos': cur_pos,
                        }
                        # int() normalizes a Python int or a one-element tensor.
                        next_token = int(generator.sample_next(
                            outputs.logits[:, -1, :],
                            aux,
                            temperature=temperature,
                            top_p=0.9,
                        ))
                        if next_token in stop_ids:
                            break

                        new_text = tokenizer.decode([next_token])
                        # Covers exact matches too (a string contains itself).
                        if not any(st in new_text for st in special_tokens):
                            yield f"data: {json.dumps({'token': new_text, 'done': False})}\n\n"

                        tokens[0, cur_pos] = next_token
                        generated.append(next_token)
                        prev_pos = cur_pos

                    # Decode all generated ids together; per-token decoding
                    # above can split multi-token characters.
                    final_text = tokenizer.decode(generated)
                    for st in special_tokens:
                        final_text = final_text.replace(st, '')
                    yield f"data: {json.dumps({'text': final_text, 'done': True})}\n\n"
                except Exception as e:
                    app.logger.exception('Error generating text: %s', e)
                    yield f"data: {json.dumps({'error': str(e)})}\n\n"

            return Response(stream_with_context(generate_stream()), mimetype='text/event-stream')
        except Exception as e:
            app.logger.exception('Server error: %s', e)
            return jsonify({'error': f'Server error: {str(e)}'}), 500

    return app
# Module-level application instance. It is built at import time, so importing
# this module loads the tokenizer and model.
app = create_app()

if __name__ == "__main__":
    # Direct execution: serve on all interfaces, port 7860.
    host, port = '0.0.0.0', 7860
    app.run(host=host, port=port)