Spaces:
Running
Running
import asyncio
import base64
import json
import os
import re
import secrets
import time
from datetime import datetime, timedelta, timezone
from typing import Optional
from urllib.parse import parse_qs

import google.generativeai as genai
import psutil
import torch
from deep_translator import GoogleTranslator
from fastapi import Depends, FastAPI, File, Header, HTTPException, Request, UploadFile
from fastapi.responses import JSONResponse
from init_data_py import InitData
from PIL import Image
from pydantic import BaseModel
from starlette.middleware.base import BaseHTTPMiddleware
from supabase import Client, create_client
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
# The ASGI application object; interactive OpenAPI docs are served at /docs.
app = FastAPI(
    docs_url="/docs",
    version="2.1",
    title="Spam Detection API",
    description="API for text spam detection and AI content generation",
)
def translate_to_russian(text):
    """Translate *text* into Russian via Google Translate, best effort.

    Returns the original *text* unchanged in three cases:
    - it is empty;
    - the translator raises for any reason (network error, rejected input, ...);
    - the translator yields an empty result.
    Callers therefore always get the input back rather than None or an exception.
    """
    if not text:
        return text
    try:
        translated = GoogleTranslator(source="auto", target="ru").translate(text=text)
    except Exception:
        # Deliberately best-effort: fall back to the untranslated input.
        # `Exception` (not a bare except) so KeyboardInterrupt/SystemExit propagate.
        return text
    return translated or text
# Sliding-window rate limiting, tracked per client IP in process memory.
TIME_WINDOW = 60  # seconds covered by the sliding window
RATE_LIMIT = 1000  # max requests per client inside TIME_WINDOW
BLOCK_TIME = 60  # seconds a client is rejected after hitting RATE_LIMIT
request_counts = {}  # client ip -> timestamps of that client's recent requests
block_times = {}  # client ip -> time.time() at which the block started


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Reject clients that send RATE_LIMIT requests within TIME_WINDOW seconds.

    A client that hits the limit gets HTTP 429 responses for BLOCK_TIME
    seconds. State lives in the module-level ``request_counts`` and
    ``block_times`` dicts, so it is per process and not shared between workers.
    """

    # time.time() of the last full sweep of stale entries (see _sweep_stale).
    _last_sweep = 0.0

    @staticmethod
    def _sweep_stale(now):
        """Forget clients with no requests in the window and blocks that have expired.

        Without this, every IP ever seen would stay in the dicts forever.
        """
        for ip in list(request_counts):
            recent = [t for t in request_counts[ip] if now - t < TIME_WINDOW]
            if recent:
                request_counts[ip] = recent
            else:
                del request_counts[ip]
        for ip in list(block_times):
            if now - block_times[ip] >= BLOCK_TIME:
                del block_times[ip]

    async def dispatch(self, request: Request, call_next):
        # request.client may be None (e.g. some test transports); bucket those together.
        client_ip = request.client.host if request.client else "unknown"
        current_time = time.time()

        # Amortised cleanup: at most one full sweep per TIME_WINDOW.
        if current_time - RateLimitMiddleware._last_sweep >= TIME_WINDOW:
            RateLimitMiddleware._last_sweep = current_time
            self._sweep_stale(current_time)

        if client_ip in block_times:
            remaining_time = BLOCK_TIME - (current_time - block_times[client_ip])
            if remaining_time > 0:
                return JSONResponse(
                    status_code=429,
                    content={"ok": False, "error": f"Too Many Requests, try again in {int(remaining_time)} seconds"},
                )
            del block_times[client_ip]

        if client_ip in request_counts:
            recent = [t for t in request_counts[client_ip] if current_time - t < TIME_WINDOW]
            request_counts[client_ip] = recent
            if len(recent) >= RATE_LIMIT:
                block_times[client_ip] = current_time
                return JSONResponse(
                    status_code=429,
                    content={"ok": False, "error": f"Too Many Requests, try again in {BLOCK_TIME} seconds"},
                )
        request_counts.setdefault(client_ip, []).append(current_time)
        return await call_next(request)
# Inference device for the text classifiers: GPU when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Credentials and endpoints read from the environment (None when unset).
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
SUPABASE_URL = os.environ.get("SUPABASE_URL")
SUPABASE_KEY = os.environ.get("SUPABASE_KEY")
BOT_TOKEN = os.environ.get("bot_token")
GEMINI_KEY = os.environ.get("Gemini")
OWNERSKEY = os.environ.get("ownersKey")

# Shared Supabase client used for the `tokens` and `encryption_keys` tables.
supabase: Client = create_client(SUPABASE_URL, SUPABASE_KEY)

# Filled by initialize_models(), keyed by Hugging Face model id.
models = {}
tokenizers = {}

# Lifetime (seconds) written into `expires_at` for keys issued by generate_key.
KEY_VALIDITY_SECONDS = 120
class GenerateKeyRequest(BaseModel):
    """Request body for generate_key."""

    tg_id: str  # Telegram user id the key is issued for


async def generate_key(request: Request, body: GenerateKeyRequest):
    """Issue a short-lived random encryption key for a Telegram user.

    Stores the key in ``encryption_keys`` with ``api_address`` set to the
    Telegram id and ``expires_at`` set to now + KEY_VALIDITY_SECONDS (UTC).
    The key and its expiry are echoed back to the caller.

    Raises:
        HTTPException: 400 when ``tg_id`` is empty.
    """
    # NOTE(review): no route decorator is visible for this handler in this copy
    # of the file — confirm it is registered on `app`.
    if not body.tg_id:
        raise HTTPException(status_code=400, detail="Telegram ID is required")
    tg_id = body.tg_id

    key = secrets.token_urlsafe(32)
    expiry = (datetime.now(timezone.utc) + timedelta(seconds=KEY_VALIDITY_SECONDS)).isoformat()

    row = {"api_address": tg_id, "encryption_key": key, "expires_at": expiry}
    supabase.table("encryption_keys").insert(row).execute()

    payload = {"ok": True, "encryption_key": key, "expires_at": expiry, "api_address": tg_id}
    return JSONResponse(status_code=200, content=payload)
import re | |
def clean_text(text):
    """Normalize a message before tokenization.

    Steps:
    1. Strip the text and turn newlines into spaces.
    2. Drop every character that is not a word character, whitespace, comma or
       period (this also removes ``!`` and ``?``).
    3. Drop digits.
    4. If the text contains Russian letters (including ё/Ё), drop standalone
       Latin words.
    5. Collapse whitespace and lower-case the result.
    """
    text = text.strip().replace("\n", " ")
    # str patterns are Unicode-aware by default, so \w also keeps Cyrillic.
    text = re.sub(r"[^\w\s,.]", "", text)
    text = re.sub(r"\d+", "", text)
    if re.search(r"[а-яА-ЯёЁ]", text):
        # After the filter above only the `www.` alternative of the lookahead
        # can still match ('@', ':' and '/' are already gone).
        text = re.sub(r"\b(?!@|https?://|www\.)[a-zA-Z]+\b", "", text)
    return re.sub(r"\s+", " ", text).strip().lower()
class MessageRequest(BaseModel):
    """Request body for the ruSpam text classification endpoint."""

    message: str  # text to classify
    model_name: str  # any value containing "tiny" selects the tiny models
    # Optional: both an omitted field and an explicit JSON null mean "no Telegram id".
    telegram_id: Optional[str] = None
def initialize_models():
    """Load the four ruSpam classifiers and their tokenizers.

    Each model is built with a single output label (``num_labels=1``), moved
    to ``device`` and put in eval mode. Results are stored in the module-level
    ``models`` / ``tokenizers`` dicts under the Hugging Face model id.
    """
    # NOTE(review): recent transformers releases deprecate `use_auth_token` in
    # favour of `token` — confirm against the pinned version before upgrading.
    checkpoints = (
        "ru-spam/ruSpamNS_v9_Detector",
        "ru-spam/ruSpamNS_v9_Precision",
        "ru-spam/ruSpamNS_v9_Detector_tiny",
        "ru-spam/ruSpamNS_v9_Precision_tiny",
    )
    for checkpoint in checkpoints:
        network = AutoModelForSequenceClassification.from_pretrained(
            checkpoint, num_labels=1, use_auth_token=huggingface_token
        )
        models[checkpoint] = network.to(device).eval()
        tokenizers[checkpoint] = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=huggingface_token)


# Import-time startup: load the text models, the NSFW image classifier, and
# configure the Gemini SDK with its API key.
initialize_models()
classifier = pipeline("image-classification", model="AdamCodd/vit-base-nsfw-detector")
genai.configure(api_key=GEMINI_KEY)
def check_webapp_signature(init_data: str) -> bool:
    """Validate Telegram WebApp init data against BOT_TOKEN.

    Init data older than 120 seconds is rejected (``lifetime=120``). Any
    parsing or validation error is reported as an invalid signature (False).
    """
    try:
        parsed = InitData.parse(init_data)
        is_valid = parsed.validate(BOT_TOKEN, lifetime=120)
    except Exception:
        is_valid = False
    return is_valid
def contains_russian(text):
    """Return True if *text* contains at least one Russian letter, including ё/Ё.

    ё/Ё (U+0451/U+0401) lie outside the а-я/А-Я ranges, so they are listed
    explicitly.
    """
    return re.search(r"[а-яА-ЯёЁ]", text) is not None
async def get_balance(api_key):
    """Return the stored balance for *api_key* from the ``tokens`` table.

    The owner key short-circuits to the string "Owner". A key with no
    matching row yields 0.0. The Supabase query is not awaited, so it runs
    synchronously on the calling thread.
    """
    if api_key == OWNERSKEY:
        return "Owner"
    rows = supabase.table("tokens").select("balance").eq("token", api_key).execute().data
    if not rows:
        return 0.0
    return rows[0].get("balance", 0.0)
async def deduct_balance(api_key, tokens_used):
    """Charge *api_key* for *tokens_used* tokens; return True if the call may proceed.

    Price: 0.04 per 128 tokens. The owner key and valid Telegram WebApp init
    data (a credential containing "signature") are never charged. Returns
    False, without charging, when the balance is below the cost.
    """
    free_of_charge = api_key == OWNERSKEY or (
        "signature" in api_key and check_webapp_signature(api_key)
    )
    if free_of_charge:
        return True

    cost = tokens_used / 128 * 0.04
    balance = await get_balance(api_key)
    if balance < cost:
        return False
    # NOTE(review): read-then-write is not atomic; two concurrent requests can
    # both pass the check above. A DB-side conditional update/RPC would fix it.
    supabase.table("tokens").update({"balance": balance - cost}).eq("token", api_key).execute()
    return True
def xor_cipher(input_text, key):
    """XOR each character's code point with the key character at the same index modulo len(key).

    The transform is its own inverse. A non-empty *input_text* with an empty
    *key* raises ZeroDivisionError.
    """
    key_len = len(key)
    return "".join(
        chr(ord(key[pos % key_len]) ^ ord(char))
        for pos, char in enumerate(input_text)
    )


def decrypt_data(data, key):
    """Base64-decode *data*, decode the bytes as UTF-8 and XOR the text with *key*."""
    raw = base64.b64decode(data)
    return xor_cipher(raw.decode(), key)
def _encryption_key_expired(expires_at) -> bool:
    """Return True if a stored ``expires_at`` timestamp lies in the past.

    Missing or unparseable values count as not expired, so a malformed row
    never locks a user out.
    """
    if not expires_at:
        return False
    try:
        moment = datetime.fromisoformat(str(expires_at).replace("Z", "+00:00"))
    except ValueError:
        return False
    if moment.tzinfo is None:
        # NOTE(review): naive timestamps are assumed to be UTC, matching what
        # generate_key writes (datetime.now(timezone.utc).isoformat()).
        moment = moment.replace(tzinfo=timezone.utc)
    return moment <= datetime.now(timezone.utc)


async def get_last_encryption_key(tg_id: str) -> Optional[str]:
    """Return the newest non-expired encryption key issued for *tg_id*, or None.

    Side effect: every older key row for the same ``api_address`` is deleted.
    The newest key is ignored (None is returned) once its ``expires_at`` has
    passed. Previously ``expires_at`` was stored but never checked.
    """
    result = (
        supabase.table("encryption_keys")
        .select("*")
        .eq("api_address", tg_id)
        .order("expires_at", desc=True)
        .execute()
    )
    rows = result.data
    if not rows:
        return None
    latest = rows[0]
    if len(rows) > 1:
        stale_ids = [row["id"] for row in rows[1:]]
        supabase.table("encryption_keys").delete().in_("id", stale_ids).execute()
    if _encryption_key_expired(latest.get("expires_at")):
        return None
    return latest["encryption_key"]
class CacheRequestBodyMiddleware(BaseHTTPMiddleware):
    """Read the whole request body and expose it as ``request.state.body``.

    get_telegram_id parses this cached copy.
    """

    async def dispatch(self, request: Request, call_next):
        request.state.body = await request.body()
        # NOTE(review): the body is fully buffered in memory (image uploads
        # included). On older Starlette releases, reading it inside a
        # BaseHTTPMiddleware could break downstream reads — confirm the pinned version.
        return await call_next(request)


# Starlette wraps middleware in reverse registration order, so
# RateLimitMiddleware (added last) runs first and rejects throttled
# requests before their body is read.
app.add_middleware(CacheRequestBodyMiddleware)
app.add_middleware(RateLimitMiddleware)
async def get_telegram_id(request: Request) -> Optional[str]:
    """FastAPI dependency: return ``telegram_id`` from the request's JSON body.

    Reads the raw bytes cached on ``request.state.body`` by
    CacheRequestBodyMiddleware. Returns None in any of these cases:
    - the cache is missing;
    - the body is not valid JSON (e.g. multipart uploads);
    - the body is not a JSON object;
    - the body has no ``telegram_id`` key.
    """
    try:
        body_data = json.loads(request.state.body)
    except Exception:
        # Best effort: an unreadable body just means "no Telegram id".
        # `Exception` rather than a bare except so BaseException is not swallowed.
        return None
    if not isinstance(body_data, dict):
        return None
    return body_data.get("telegram_id")
def _telegram_user_id(init_data: str):
    """Return ``user.id`` from URL-encoded Telegram init data, or None if absent/malformed."""
    user_json = parse_qs(init_data).get("user", [None])[0]
    if not user_json:
        return None
    try:
        user = json.loads(user_json)
    except ValueError:
        return None
    return user.get("id") if isinstance(user, dict) else None


def _require_init_data_for(init_data: str, tg_id) -> None:
    """Raise HTTP 403 unless *init_data* is validly signed and belongs to *tg_id*."""
    if not check_webapp_signature(init_data):
        raise HTTPException(403, "Invalid Telegram init data")
    user_id = _telegram_user_id(init_data)
    # A missing/malformed `user` field cannot match any id -> 403, not a 500.
    if user_id is None or str(user_id) != str(tg_id):
        raise HTTPException(403, "Telegram ID mismatch")


async def verify_api_key(request: Request, api_key: str = Header(...), tg_id: str = Depends(get_telegram_id)):
    """FastAPI dependency that authenticates a request and returns the credential to bill.

    Credentials are checked in this order:
    1. The owner key (OWNERSKEY) is returned as-is.
    2. If an encryption key is stored for the body's ``telegram_id``, the
       ``api_key`` header is XOR-decrypted with it. If the plaintext contains
       "signature", it is validated as Telegram WebApp init data for that id
       and returned.
    3. Otherwise, if no key is stored and the header itself contains
       "signature", the header is validated as raw init data and returned.
    4. Anything else must be a token with a truthy balance in ``tokens``.

    Both init-data paths (steps 2 and 3) use the same user-id extraction.

    Raises:
        HTTPException: 403 for invalid init data, a Telegram id mismatch, or
        an unknown / zero-balance key.
    """
    if api_key == OWNERSKEY:
        return api_key

    last_key = await get_last_encryption_key(tg_id) if tg_id else None
    if last_key:
        try:
            decrypted = decrypt_data(api_key, last_key)
        except Exception:
            # Not base64 / not UTF-8: treat the header as a plain API key below.
            decrypted = None
        if decrypted and "signature" in decrypted:
            _require_init_data_for(decrypted, tg_id)
            return decrypted
    elif "signature" in api_key:
        _require_init_data_for(api_key, tg_id)
        return api_key

    if not await get_balance(api_key):
        raise HTTPException(403, "Invalid API key or insufficient balance")
    return api_key
async def custom_http_exception_handler(request: Request, exc: HTTPException):
    """Render an HTTPException as ``{"ok": false, "error": <detail>}`` with its own status code."""
    # NOTE(review): no `@app.exception_handler` registration is visible in this
    # copy of the file — confirm it is wired up.
    body = {"ok": False, "error": exc.detail}
    return JSONResponse(content=body, status_code=exc.status_code)
async def global_exception_handler(request: Request, exc: Exception):
    """Render any unhandled exception as HTTP 500 ``{"ok": false, "error": str(exc)}``."""
    # NOTE(review): str(exc) is sent to clients verbatim and may expose internal
    # details (DB errors, paths); consider a generic message plus server-side logging.
    message = str(exc)
    return JSONResponse(content={"ok": False, "error": message}, status_code=500)
async def root(request: Request):
    """Health-check style response that always reports OK."""
    payload = {"ok": True, "status": "OK"}
    return payload
async def generative_message(request: Request, message: str, api_key: str = Depends(verify_api_key)):
    """Send *message* to the tuned Gemini model and return its text reply.

    Uses deterministic sampling (temperature 0). The caller is authenticated
    via verify_api_key, but this handler does not call deduct_balance.
    """

    def call_model():
        model = genai.GenerativeModel(
            model_name="tunedModels/outputfile-gnd8p2ou5u42",
            generation_config={"temperature": 0, "top_p": 0.95, "top_k": 40, "max_output_tokens": 8192},
        )
        return model.start_chat().send_message(message).text

    # The Gemini SDK call is blocking network I/O; run it in the default thread
    # pool (as gemini_spam does) so the event loop keeps serving other requests.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, call_model)
    return {"ok": True, "response": response}
async def status(request: Request):
    """Report host health: time since system boot plus CPU, memory and root-disk usage in percent."""
    seconds_since_boot = int(time.time() - psutil.boot_time())
    return {
        "ok": True,
        "uptime": str(timedelta(seconds=seconds_since_boot)),
        "cpu": psutil.cpu_percent(),
        "memory": psutil.virtual_memory().percent,
        "disk": psutil.disk_usage("/").percent,
    }
async def check_image(request: Request, file: UploadFile = File(...), api_key: str = Depends(verify_api_key)):
    """Classify an uploaded image with the NSFW image classifier.

    Returns whether the top-scoring label is NSFW, together with that label's
    score.

    Raises:
        HTTPException: 400 when the upload cannot be opened as an image.
    """
    try:
        image = Image.open(file.file)
    except OSError:
        # PIL.UnidentifiedImageError subclasses OSError.
        raise HTTPException(400, "Invalid image file") from None
    with image:
        result = classifier(image)[0]
    # NOTE(review): the model's labels may be lower-case ("nsfw"/"sfw"), so
    # compare case-insensitively — confirm against the model's id2label.
    is_nsfw = str(result["label"]).lower() == "nsfw"
    return {"ok": True, "is_nsfw": is_nsfw, "confidence": result["score"]}
def _spam_probability(model_key, encoding):
    """Run the classifier registered under *model_key* on *encoding*; return sigmoid(logit) as a float."""
    with torch.no_grad():
        return torch.sigmoid(models[model_key](**encoding).logits).item()


async def check_spam(request: Request, data: MessageRequest, api_key: str = Depends(verify_api_key)):
    """Classify ``data.message`` as spam with the ruSpam models.

    Pipeline:
    1. Text without Russian letters is translated to Russian.
    2. The text is normalized with clean_text.
    3. The Detector model scores it; the tiny variant is used when
       ``data.model_name`` contains "tiny".
    4. Scores in [0.5, 0.75) are re-scored by the matching Precision model.

    For non-owner callers, deduct_balance is called before each model pass.

    Raises:
        HTTPException: 403 when a charge cannot be covered.
    """
    start_time = time.time()
    model_name = "ru-spam/ruSpamNS_v9_Detector_tiny" if "tiny" in data.model_name.lower() else "ru-spam/ruSpamNS_v9_Detector"
    precision_model = model_name.replace("Detector", "Precision")

    if contains_russian(data.message):
        message = data.message
    else:
        # Translation is a blocking network call; keep it off the event loop.
        loop = asyncio.get_running_loop()
        message = await loop.run_in_executor(None, translate_to_russian, data.message)
    cleaned = clean_text(message)

    encoding = tokenizers[model_name](cleaned, max_length=256, truncation=True, return_tensors="pt").to(device)
    tokens_used = len(encoding["input_ids"][0])
    if data.telegram_id:
        print(f"Telegram ID: {data.telegram_id}")
    if api_key != OWNERSKEY and not await deduct_balance(api_key, tokens_used):
        raise HTTPException(403, "Insufficient funds")
    pred = _spam_probability(model_name, encoding)

    # Borderline detector scores get a second opinion from the Precision model.
    if 0.5 <= pred < 0.75:
        precision_encoding = tokenizers[precision_model](cleaned, max_length=256, truncation=True, return_tensors="pt").to(device)
        precision_tokens = len(precision_encoding["input_ids"][0])
        if api_key != OWNERSKEY and not await deduct_balance(api_key, precision_tokens):
            raise HTTPException(403, "Insufficient funds")
        pred = _spam_probability(precision_model, precision_encoding)
        tokens_used += precision_tokens

    result_spam = 1 if pred >= 0.5 else 0
    cost = (tokens_used / 128) * 0.04
    processing_time = time.time() - start_time
    remaining_balance = await get_balance(api_key) if api_key != OWNERSKEY else "Owner"
    return {"ok": True, "model": data.model_name, "balance": remaining_balance, "is_spam": result_spam, "confidence": pred, "tokens_used": tokens_used, "cost": cost, "processing_time": processing_time}
async def balance(request: Request, api_key: str = Depends(verify_api_key)):
    """Return the authenticated caller's balance as reported by get_balance ("Owner" for the owner key)."""
    return {"ok": True, "balance": await get_balance(api_key)}
class GeminiSpamRequest(BaseModel):
    """Request body for the Gemini-based spam check."""

    message: str  # text to classify
    # Optional: both an omitted field and an explicit JSON null mean "no Telegram id".
    telegram_id: Optional[str] = None
async def gemini_spam(data: GeminiSpamRequest, request: Request, api_key: str = Depends(verify_api_key)):
    """Ask Gemini (gemini-2.0-flash) whether ``data.message`` is spam.

    The model is prompted to answer with a bare "1" (spam) or "0" (not spam).
    Any reply containing "1" is treated as spam. The raw reply is returned
    alongside the verdict.

    Raises:
        HTTPException: 502 when the Gemini reply carries no text.
    """
    prompt = f"""
<task>
<description>Анализируй входное сообщение и определи, является ли оно спамом.</description>
<input>{data.message}</input>
<output_format>
<spam>1</spam>
<not_spam>0</not_spam>
</output_format>
<criteria>Сообщение должно быть полным, законченным и осмысленным, а не отдельными словами или фразами, чтобы считаться спамом. Например "Онлaйн рaбoтa" не спам.</criteria>
<strict>Выводи только число без лишних символов и текста.</strict>
</task>
"""

    def call_gemini():
        # NOTE(review): temperature 1 makes this yes/no classification
        # non-deterministic; consider 0 if repeatable verdicts are wanted.
        model = genai.GenerativeModel(
            model_name="gemini-2.0-flash",
            generation_config={"temperature": 1, "top_p": 0.95, "top_k": 40, "max_output_tokens": 8192, "response_mime_type": "text/plain"},
        )
        chat_session = model.start_chat(history=[])
        return chat_session.send_message(prompt)

    # The SDK call is blocking; run it in the default thread pool.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(None, call_gemini)
    try:
        result_text = response.text.strip()
    except ValueError:
        # NOTE(review): google-generativeai raises ValueError from `.text` when
        # the reply has no text part (e.g. safety-blocked) — confirm for the SDK version.
        raise HTTPException(502, "Gemini returned no text") from None
    # "1" -> spam, "0" -> not spam; for chatty replies any "1" counts as spam.
    is_spam = "1" in result_text
    return {"ok": True, "is_spam": is_spam, "response": result_text}
if __name__ == "__main__":
    # Local development entry point: serve on all interfaces, port 5002.
    import uvicorn

    uvicorn.run(app, port=5002, host="0.0.0.0")