File size: 2,785 Bytes
097a740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# 12.2_predict_multilabel_file.py
# Użycie: python scripts/12.2_predict_multilabel_file.py test/Dockerfile --debug

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import sys
from pathlib import Path
import numpy as np
import json

# === Paths and inference settings
# Directory handed to from_pretrained() for both the tokenizer and the model.
MODEL_DIR = Path("models/multilabel/")
# JSON file read by load_labels(); predict() indexes its content by position.
TOP_RULES_PATH = Path("data/metadata/top_rules.json")
# Passed to the tokenizer as max_length (together with truncation and
# "max_length" padding) in predict().
MAX_LENGTH = 512
THRESHOLD = 0.5  # Detection threshold: predict() reports a rule when p > threshold

# === Załaduj reguły
def load_labels():
    """Return the parsed JSON content of ``TOP_RULES_PATH``.

    The file is decoded as UTF-8. Whatever object the JSON document holds is
    returned unchanged.
    """
    raw_json = TOP_RULES_PATH.read_text(encoding="utf-8")
    return json.loads(raw_json)

# === Załaduj model i tokenizer
def load_model_and_tokenizer():
    """Load the tokenizer and the sequence-classification model from ``MODEL_DIR``.

    Returns:
        A ``(tokenizer, model)`` tuple; ``eval()`` has already been called on
        the model.

    Raises:
        FileNotFoundError: If ``MODEL_DIR`` does not exist.
    """
    # Guard clause: fail before touching transformers when the directory is missing.
    if not MODEL_DIR.exists():
        raise FileNotFoundError(f"❌ Nie znaleziono katalogu z modelem: {MODEL_DIR}")

    tok = AutoTokenizer.from_pretrained(MODEL_DIR)
    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    classifier.eval()
    return tok, classifier

# === Predykcja
def predict(filepath: Path, tokenizer, model, labels, threshold=THRESHOLD, debug=False):
    """Run multi-label inference on one file and print the rules that fired.

    Args:
        filepath: File to classify; it is read as UTF-8 text.
        tokenizer: Callable applied to the text; its result is unpacked into
            ``model(**inputs)`` and must expose ``input_ids``.
        model: Callable whose result has a ``logits`` attribute.
        labels: Rule names, indexed by position. Assumed to be in the same
            order as the model outputs -- TODO confirm against training.
        threshold: A rule is reported when its sigmoid probability is
            strictly greater than this value.
        debug: When true, also print a text excerpt, the token count, the raw
            logits and the five highest-scoring rules.

    Returns:
        None. All results are printed to stdout.
    """
    text = filepath.read_text(encoding="utf-8")

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
    )

    with torch.no_grad():
        logits = model(**inputs).logits
        # Flatten instead of squeeze(): squeeze() drops *every* size-1 axis,
        # so a single-output model would yield a 0-d array that cannot be
        # indexed below. reshape(-1) always gives a 1-D vector for the single
        # input processed here.
        flat_logits = logits.reshape(-1)
        probs = torch.sigmoid(flat_logits).cpu().numpy()

    triggered = [
        (label, probs[i])
        for i, label in enumerate(labels)
        if probs[i] > threshold
    ]

    print(f"\n🧪 Predykcja dla pliku: {filepath.name}")
    print(f"📄 Długość pliku: {len(text.splitlines())} linii")

    if triggered:
        print(f"\n🚨 Wykryte reguły (p > {threshold}):")
        for rule, p in triggered:
            print(f" - {rule}: {p:.3f}")
    else:
        print("✅ Brak wykrytych problemów (żadna reguła nie przekroczyła progu)")

    if debug:
        # The ranking is only needed for the debug listing, so compute it here.
        top5 = np.argsort(probs)[-5:][::-1]

        print("\n🛠 DEBUG INFO:")
        print(f"📝 Fragment tekstu:\n{text[:300]}")
        # NOTE(review): with padding="max_length" this length presumably
        # includes padding positions -- confirm before relying on it.
        print(f"🔢 Tokenów: {len(inputs['input_ids'][0])}")
        print(f"📈 Logity: {flat_logits.tolist()}")
        print("\n🔥 Top 5 predykcji:")
        for idx in top5:
            print(f" - {labels[idx]}: {probs[idx]:.3f}")

# === Główna funkcja
def main():
    """Command-line entry point: classify the file named on the command line.

    The first argument that is not ``--debug`` is taken as the file path, so
    the flag may appear before or after it. The path is validated before the
    labels and the model are loaded.

    Exits with status 1 (after printing a message) when no path is given,
    when the path does not exist, or when it is not a regular file.
    """
    cli_args = sys.argv[1:]
    debug = "--debug" in cli_args
    # Ignore the flag when looking for the path so that
    # "--debug <file>" works the same as "<file> --debug".
    positional = [arg for arg in cli_args if arg != "--debug"]

    if not positional:
        print("❌ Użycie: python scripts/12.2_predict_multilabel_file.py /ścieżka/do/pliku.Dockerfile [--debug]")
        sys.exit(1)

    filepath = Path(positional[0])

    if not filepath.exists():
        print(f"❌ Plik {filepath} nie istnieje.")
        sys.exit(1)

    # A directory passes exists() but cannot be read as text; reject it here
    # rather than failing after the model has been loaded.
    if not filepath.is_file():
        print(f"❌ {filepath} nie jest plikiem.")
        sys.exit(1)

    labels = load_labels()
    tokenizer, model = load_model_and_tokenizer()
    predict(filepath, tokenizer, model, labels, debug=debug)

if __name__ == "__main__":
    main()