# (repository-viewer metadata removed: file size 2,785 bytes, commit 097a740)
# 12.2_predict_multilabel_file.py
# Usage: python scripts/12.2_predict_multilabel_file.py test/Dockerfile --debug
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import sys
from pathlib import Path
import numpy as np
import json
# === Paths and inference settings
MODEL_DIR = Path("models/multilabel/")  # directory holding the fine-tuned multilabel model + tokenizer
TOP_RULES_PATH = Path("data/metadata/top_rules.json")  # JSON file with rule labels (indexed by model output position)
MAX_LENGTH = 512  # tokenizer truncation/padding length
THRESHOLD = 0.5  # Detection threshold: probability above which a rule counts as triggered
# === Load rule labels
def load_labels():
    """Read and return the rule labels stored as JSON at TOP_RULES_PATH."""
    return json.loads(TOP_RULES_PATH.read_text(encoding="utf-8"))
# === Load model and tokenizer
def load_model_and_tokenizer():
    """Load the fine-tuned tokenizer and model from MODEL_DIR.

    Raises:
        FileNotFoundError: if MODEL_DIR does not exist.
    """
    # Guard clause: fail fast when the model directory is missing.
    if not MODEL_DIR.exists():
        raise FileNotFoundError(f"❌ Nie znaleziono katalogu z modelem: {MODEL_DIR}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.eval()  # inference mode: disable dropout etc.
    return tokenizer, model
# === Prediction
def predict(filepath: Path, tokenizer, model, labels, threshold=THRESHOLD, debug=False):
    """Run multilabel inference on one file and print the triggered rules.

    Args:
        filepath: Path of the file to classify (read as UTF-8 text).
        tokenizer: Hugging Face tokenizer matching the model.
        model: Sequence-classification model emitting one logit per label.
        labels: Rule names, in the model's output-label order.
        threshold: Probability above which a rule counts as detected.
        debug: When True, also print tokens/logits and the top-5 scores.
    """
    text = filepath.read_text(encoding="utf-8")
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # squeeze(0), not squeeze(): a plain squeeze() on a (1, 1) logit tensor
    # (single-label model) would produce a 0-d array and break indexing below.
    probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
    triggered = [(label, p) for label, p in zip(labels, probs) if p > threshold]
    print(f"\n🧪 Predykcja dla pliku: {filepath.name}")
    print(f"📄 Długość pliku: {len(text.splitlines())} linii")
    if triggered:
        print(f"\n🚨 Wykryte reguły (p > {threshold}):")
        for rule, p in triggered:
            print(f" - {rule}: {p:.3f}")
    else:
        print("✅ Brak wykrytych problemów (żadna reguła nie przekroczyła progu)")
    if debug:
        print("\n🛠 DEBUG INFO:")
        print(f"📝 Fragment tekstu:\n{text[:300]}")
        print(f"🔢 Tokenów: {len(inputs['input_ids'][0])}")
        print(f"📈 Logity: {logits.squeeze().tolist()}")
        # Top-5 is only shown in debug mode, so compute it here instead of
        # unconditionally on every call (as the original did).
        top5 = np.argsort(probs)[-5:][::-1]
        print("\n🔥 Top 5 predykcji:")
        for idx in top5:
            print(f" - {labels[idx]}: {probs[idx]:.3f}")
# === Entry point
def main():
    """Parse CLI arguments, validate the target file, and run prediction."""
    args = sys.argv[1:]
    if not args:
        print("❌ Użycie: python scripts/12.2_predict_multilabel_file.py /ścieżka/do/pliku.Dockerfile [--debug]")
        sys.exit(1)
    target = Path(args[0])
    if not target.exists():
        print(f"❌ Plik {target} nie istnieje.")
        sys.exit(1)
    labels = load_labels()
    tokenizer, model = load_model_and_tokenizer()
    predict(target, tokenizer, model, labels, debug="--debug" in args)


if __name__ == "__main__":
    main()
|