# Repository path: multilabel-dockerfile-model/scripts/12.2_predict_multilabel_file.py
# Uploaded by LeeSek in commit 097a740 ("Add scripts", verified)
# 12.2_predict_multilabel_file.py
# Usage: python scripts/12.2_predict_multilabel_file.py test/Dockerfile --debug
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import sys
from pathlib import Path
import numpy as np
import json
# === Paths and inference settings
MODEL_DIR = Path("models") / "multilabel"  # directory holding the saved model and tokenizer
TOP_RULES_PATH = Path("data") / "metadata" / "top_rules.json"  # JSON file with the rule labels
MAX_LENGTH = 512  # max_length handed to the tokenizer
THRESHOLD = 0.5  # detection threshold
# === Load the rule labels
def load_labels():
    """Return the parsed JSON content of ``TOP_RULES_PATH``.

    The file is read as UTF-8.

    NOTE(review): ``predict`` indexes the result by integer position, so the
    file is presumably a JSON array of rule names - confirm against the data.
    """
    raw_json = TOP_RULES_PATH.read_text(encoding="utf-8")
    return json.loads(raw_json)
# === Load the model and tokenizer
def load_model_and_tokenizer():
    """Load the tokenizer and sequence-classification model from ``MODEL_DIR``.

    Returns:
        A ``(tokenizer, model)`` tuple; ``eval()`` has been called on the model.

    Raises:
        FileNotFoundError: if ``MODEL_DIR`` does not exist.
    """
    # Guard first so the happy path below reads straight through.
    if not MODEL_DIR.exists():
        raise FileNotFoundError(f"❌ Nie znaleziono katalogu z modelem: {MODEL_DIR}")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    classifier.eval()
    return tokenizer, classifier
# === Prediction
def predict(filepath: Path, tokenizer, model, labels, threshold=THRESHOLD, debug=False):
    """Score one file against every rule and print the rules above the threshold.

    Args:
        filepath: File to classify; read as UTF-8 text.
        tokenizer: Callable that turns the text into model inputs. It is
            called with ``return_tensors="pt"``, truncation and
            ``max_length`` padding up to ``MAX_LENGTH``.
        model: Callable whose result exposes a ``logits`` tensor.
        labels: Rule names; ``labels[i]`` names the i-th model output.
        threshold: A rule is reported when its sigmoid probability is
            strictly greater than this value.
        debug: When true, also print a text excerpt, the token count, the
            raw logits and the five highest-scoring rules.

    Returns:
        None. All results are written to standard output.

    Raises:
        IndexError: if ``labels`` has more entries than the model has outputs.
    """
    text = filepath.read_text(encoding="utf-8")
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # Flatten to 1-D instead of calling squeeze(): squeeze() drops *every*
    # size-1 axis, so a model with a single label would yield a 0-d array
    # that cannot be indexed. For the usual (1, num_labels) output both
    # forms give the same vector.
    probs = torch.sigmoid(logits).reshape(-1).cpu().numpy()

    triggered = [
        (label, probs[i])
        for i, label in enumerate(labels)
        if probs[i] > threshold
    ]

    print(f"\n🧪 Predykcja dla pliku: {filepath.name}")
    print(f"📄 Długość pliku: {len(text.splitlines())} linii")

    if triggered:
        print(f"\n🚨 Wykryte reguły (p > {threshold}):")
        for rule, p in triggered:
            print(f" - {rule}: {p:.3f}")
    else:
        print("✅ Brak wykrytych problemów (żadna reguła nie przekroczyła progu)")

    if debug:
        print("\n🛠 DEBUG INFO:")
        print(f"📝 Fragment tekstu:\n{text[:300]}")
        print(f"🔢 Tokenów: {len(inputs['input_ids'][0])}")
        print(f"📈 Logity: {logits.reshape(-1).tolist()}")

        # Indices of the five largest probabilities, highest first; only
        # needed for the debug listing, so computed here.
        top5 = np.argsort(probs)[-5:][::-1]
        print("\n🔥 Top 5 predykcji:")
        for idx in top5:
            print(f" - {labels[idx]}: {probs[idx]:.3f}")
# === Entry point
def main():
    """Parse the command line, load the labels and model, and run ``predict``.

    Expected invocation: ``<script> <path> [--debug]``. The ``--debug`` flag
    is accepted before or after the path.

    Exits with status 1 (after printing a message) when no path is given or
    when the path does not exist.
    """
    cli_args = sys.argv[1:]
    debug = "--debug" in cli_args

    # Separate the flag from positional arguments so that
    # "<script> --debug <path>" does not treat "--debug" as the file path.
    positional = [arg for arg in cli_args if arg != "--debug"]
    if not positional:
        print("❌ Użycie: python scripts/12.2_predict_multilabel_file.py /ścieżka/do/pliku.Dockerfile [--debug]")
        sys.exit(1)

    filepath = Path(positional[0])
    if not filepath.exists():
        print(f"❌ Plik {filepath} nie istnieje.")
        sys.exit(1)

    labels = load_labels()
    tokenizer, model = load_model_and_tokenizer()
    predict(filepath, tokenizer, model, labels, debug=debug)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()