""" | |
Main Flask application for the watermark detection web interface. | |
""" | |
from flask import Flask, render_template, request, jsonify, Response, stream_with_context | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import json | |
from ..core.detector import MarylandDetector, MarylandDetectorZ, OpenaiDetector, OpenaiDetectorZ | |
from ..core.generator import WmGenerator, OpenaiGenerator, MarylandGenerator | |
from .utils import get_token_details, template_prompt | |
CACHE_DIR = "wm_interactive/static/hf_cache" | |
def convert_nan_to_null(obj):
    """Recursively convert NaN values to None so they serialize as JSON null."""
    if isinstance(obj, float) and math.isnan(obj):
        return None
    elif isinstance(obj, dict):
        return {k: convert_nan_to_null(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [convert_nan_to_null(item) for item in obj]
    return obj
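
# Illustrative example (not executed): NaN values are replaced recursively,
# so json.dumps never emits the invalid literal NaN.
#   >>> convert_nan_to_null({"score": float("nan"), "pvalues": [0.5, float("nan")]})
#   {'score': None, 'pvalues': [0.5, None]}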

def set_to_int(value, default_value=None):
    """Coerce value to an int, returning default_value when conversion fails."""
    try:
        return int(value)
    except (ValueError, TypeError):
        return default_value
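
# Illustrative examples (not executed): request parameters arrive as strings
# and may be empty or missing, hence the fallback.
#   >>> set_to_int("4")                  # 4
#   >>> set_to_int("", default_value=0)  # 0
#   >>> set_to_int(None)                 # None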

def create_detector(detector_type, tokenizer, **kwargs):
    """Create a detector instance based on the specified type."""
    detector_map = {
        'maryland': MarylandDetector,
        'marylandz': MarylandDetectorZ,
        'openai': OpenaiDetector,
        'openaiz': OpenaiDetectorZ,
    }
    # Validate and set default values for parameters
    if 'seed' in kwargs:
        kwargs['seed'] = set_to_int(kwargs['seed'], default_value=0)
    if 'ngram' in kwargs:
        kwargs['ngram'] = set_to_int(kwargs['ngram'], default_value=1)
    detector_class = detector_map.get(detector_type, MarylandDetector)
    return detector_class(tokenizer=tokenizer, **kwargs)
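
# Illustrative usage (not executed): unknown detector types fall back to
# MarylandDetector, and string parameters are coerced to ints.
#   create_detector('openaiz', tokenizer, seed='42', ngram='2')
#   # -> OpenaiDetectorZ(tokenizer=tokenizer, seed=42, ngram=2)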

def create_app():
    app = Flask(__name__,
                static_folder='../static',
                template_folder='../templates')

    # Add zip to Jinja's global context so templates can iterate in parallel
    app.jinja_env.globals.update(zip=zip)

    # Pick a model
    # model_id = "meta-llama/Llama-3.2-1B-Instruct"
    model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=CACHE_DIR)
    model = model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Create default generator (each route builds its own from request parameters)
    generator = MarylandGenerator(model, tokenizer, ngram=1, seed=0)

    # NOTE: route paths here ('/', '/tokenize', '/generate') are assumed to
    # mirror the view-function names.
    @app.route("/")
    def index():
        return render_template("index.html")

    @app.route("/tokenize", methods=["POST"])
    def tokenize():
        try:
            data = request.get_json()
            if not data:
                return jsonify({'error': 'No JSON data received'}), 400

            text = data.get('text', '')
            params = data.get('params', {})

            # Create a detector instance with the provided parameters
            detector = create_detector(
                detector_type=params.get('detector_type', 'maryland'),
                tokenizer=tokenizer,
                seed=params.get('seed', 0),
                ngram=params.get('ngram', 1)
            )

            if text:
                try:
                    display_info = get_token_details(text, detector)
                    # Extract summary stats (last item in display_info)
                    stats = display_info.pop()
                    response_data = {
                        'token_count': len(display_info),
                        'tokens': [info['token'] for info in display_info],
                        'colors': [info['color'] for info in display_info],
                        'scores': [info['score'] if info.get('is_scored', False) else None for info in display_info],
                        'pvalues': [info['pvalue'] if info.get('is_scored', False) else None for info in display_info],
                        'final_score': stats['final_score'] if stats.get('final_score') is not None else 0,
                        'ntoks_scored': stats['ntoks_scored'] if stats.get('ntoks_scored') is not None else 0,
                        'final_pvalue': stats['final_pvalue'] if stats.get('final_pvalue') is not None else 0.5,
                    }
                    # Convert any NaN values to null before sending
                    response_data = convert_nan_to_null(response_data)
                    # Restore numeric defaults if NaN values became null above
                    if response_data['final_score'] is None:
                        response_data['final_score'] = 0
                    if response_data['ntoks_scored'] is None:
                        response_data['ntoks_scored'] = 0
                    if response_data['final_pvalue'] is None:
                        response_data['final_pvalue'] = 0.5
                    return jsonify(response_data)
                except Exception as e:
                    app.logger.error(f'Error processing text: {str(e)}')
                    return jsonify({'error': f'Error processing text: {str(e)}'}), 500

            return jsonify({
                'token_count': 0,
                'tokens': [],
                'colors': [],
                'scores': [],
                'pvalues': [],
                'final_score': 0,
                'ntoks_scored': 0,
                'final_pvalue': 0.5
            })
        except Exception as e:
            app.logger.error(f'Server error: {str(e)}')
            return jsonify({'error': f'Server error: {str(e)}'}), 500
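
    # Illustrative /tokenize exchange (values hypothetical):
    #   POST {"text": "hello world", "params": {"detector_type": "maryland"}}
    #   ->   {"token_count": 2, "tokens": ["hello", " world"], "colors": [...],
    #         "scores": [null, 1.0], "pvalues": [null, 0.5],
    #         "final_score": 1.0, "ntoks_scored": 1, "final_pvalue": 0.5}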

    @app.route("/generate", methods=["POST"])
    def generate():
        try:
            data = request.get_json()
            if not data:
                return jsonify({'error': 'No JSON data received'}), 400

            prompt = template_prompt(data.get('prompt', ''))
            params = data.get('params', {})
            temperature = float(params.get('temperature', 0.8))

            def generate_stream():
                try:
                    # Create generator with correct parameters
                    generator_class = OpenaiGenerator if params.get('detector_type') == 'openai' else MarylandGenerator
                    generator = generator_class(
                        model=model,
                        tokenizer=tokenizer,
                        ngram=set_to_int(params.get('ngram', 1)),
                        seed=set_to_int(params.get('seed', 0)),
                        delta=float(params.get('delta', 2.0)),
                    )

                    # Get special tokens to filter out of the stream
                    special_tokens = {
                        '<|im_start|>', '<|im_end|>',
                        tokenizer.pad_token, tokenizer.eos_token,
                        getattr(tokenizer, 'bos_token', None),
                        getattr(tokenizer, 'sep_token', None),
                    }
                    special_tokens = {t for t in special_tokens if t is not None}

                    # Encode prompt
                    prompt_tokens = tokenizer.encode(prompt)
                    prompt_size = len(prompt_tokens)
                    max_gen_len = 100
                    total_len = min(getattr(model.config, 'max_position_embeddings', 2048), max_gen_len + prompt_size)

                    # Initialize the generation buffer, padded out to total_len.
                    # Some configs leave pad_token_id unset; fall back to EOS.
                    pad_id = model.config.pad_token_id
                    if pad_id is None:
                        pad_id = model.config.eos_token_id
                    tokens = torch.full((1, total_len), pad_id).to(model.device).long()
                    tokens[0, :prompt_size] = torch.tensor(prompt_tokens).long()
                    input_text_mask = tokens != pad_id

                    # Generate token by token
                    prev_pos = 0
                    outputs = None  # holds past_key_values between steps
                    for cur_pos in range(prompt_size, total_len):
                        # Get model outputs, reusing the KV cache after the first step
                        outputs = model.forward(
                            tokens[:, prev_pos:cur_pos],
                            use_cache=True,
                            past_key_values=outputs.past_key_values if prev_pos > 0 else None
                        )
                        # Sample the next token using the generator's sampling method
                        ngram_tokens = tokens[0, cur_pos - generator.ngram:cur_pos].tolist()
                        aux = {
                            'ngram_tokens': ngram_tokens,
                            'cur_pos': cur_pos,
                        }
                        next_token = generator.sample_next(
                            outputs.logits[:, -1, :],
                            aux,
                            temperature=temperature,
                            top_p=0.9
                        )
                        # Check for EOS token
                        if next_token == model.config.eos_token_id:
                            break
                        # Decode and stream the token unless it is a special token
                        new_text = tokenizer.decode([next_token])
                        if new_text not in special_tokens and not any(st in new_text for st in special_tokens):
                            yield f"data: {json.dumps({'token': new_text, 'done': False})}\n\n"
                        # Update token and position
                        tokens[0, cur_pos] = next_token
                        prev_pos = cur_pos

                    # Send final complete text, filtering out special tokens
                    final_tokens = tokens[0, prompt_size:cur_pos + 1].tolist()
                    final_text = tokenizer.decode(final_tokens)
                    for st in special_tokens:
                        final_text = final_text.replace(st, '')
                    yield f"data: {json.dumps({'text': final_text, 'done': True})}\n\n"
                except Exception as e:
                    app.logger.error(f'Error generating text: {str(e)}')
                    yield f"data: {json.dumps({'error': str(e)})}\n\n"

            return Response(stream_with_context(generate_stream()), mimetype='text/event-stream')
        except Exception as e:
            app.logger.error(f'Server error: {str(e)}')
            return jsonify({'error': f'Server error: {str(e)}'}), 500

    return app
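
# Illustrative client for the /generate SSE stream (hypothetical sketch, not
# executed; assumes the `requests` package and a server on localhost:7860):
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "Write a haiku.", "params": {"detector_type": "maryland"}},
#       stream=True,
#   )
#   for line in resp.iter_lines():
#       if line.startswith(b"data: "):
#           event = json.loads(line[len(b"data: "):])
#           # events are {'token': ..., 'done': False} until the final
#           # {'text': ..., 'done': True} (or {'error': ...}) event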

app = create_app()

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=7860)