add files
Browse files- .gitattributes +2 -0
- LICENSE +21 -0
- data/test.fasta +3 -0
- eval_model.py +85 -0
- output.log +1 -0
- pathoLM.png +3 -0
- utils.py +45 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
data/test.fasta filter=lfs diff=lfs merge=lfs -text
|
37 |
+
pathoLM.png filter=lfs diff=lfs merge=lfs -text
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 Sajib Acharjee Dip
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
data/test.fasta
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7f8e8010c816aa6fae67ef6a53e35b2907a5b54e610e1dac7e9912dc20526b40
|
3 |
+
size 11990791
|
eval_model.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
4 |
+
from datasets import Dataset
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef, balanced_accuracy_score, roc_auc_score, confusion_matrix
|
8 |
+
from utils import fasta_to_df
|
9 |
+
|
10 |
+
def compute_metrics(logits, labels):
    """Compute classification metrics from raw logits and true labels.

    Args:
        logits: Array-like of shape ``(n_samples, n_classes)`` holding the
            unnormalized class scores produced by the model.
        labels: Array-like of true class indices, one per sample.

    Returns:
        dict with the keys ``accuracy``, ``f1_score`` (weighted average),
        ``mcc``, ``auc_roc``, ``balanced_accuracy`` and ``confusion_matrix``
        (as a nested list). ``auc_roc`` is ``None`` unless exactly two
        distinct classes occur in ``labels``.
    """
    # Accept lists/tuples as well as ndarrays; the column indexing below
    # needs a real array.
    logits = np.asarray(logits)
    labels = np.array(labels, dtype=int)
    predictions = np.argmax(logits, axis=1).astype(int)

    acc = accuracy_score(labels, predictions)
    f1 = f1_score(labels, predictions, average='weighted')
    mcc = matthews_corrcoef(labels, predictions)
    balanced_acc = balanced_accuracy_score(labels, predictions)
    auc_roc = None

    if len(np.unique(labels)) == 2:
        # Softmax is invariant to a per-row shift, so subtract the row maximum
        # before exponentiating. Exponentiating raw logits overflows to inf
        # (and then NaN after the division) once a logit gets large.
        shifted = logits - np.max(logits, axis=1, keepdims=True)
        exp_scores = np.exp(shifted)
        probs = exp_scores[:, 1] / np.sum(exp_scores, axis=1)
        auc_roc = roc_auc_score(labels, probs)

    cm = confusion_matrix(labels, predictions)
    return {
        'accuracy': acc,
        'f1_score': f1,
        'mcc': mcc,
        'auc_roc': auc_roc,
        'balanced_accuracy': balanced_acc,
        'confusion_matrix': cm.tolist(),
    }
|
34 |
+
|
35 |
+
def encode_sequence(sequence, tokenizer, max_length):
    """Run ``tokenizer`` on ``sequence`` with fixed-length encoding options.

    The tokenizer is called with truncation enabled, ``'max_length'`` padding,
    the given ``max_length`` and ``return_tensors='pt'``; whatever the
    tokenizer returns is passed straight back to the caller.
    """
    tokenizer_options = {
        'truncation': True,
        'padding': 'max_length',
        'max_length': max_length,
        'return_tensors': 'pt',
    }
    return tokenizer(sequence, **tokenizer_options)
|
37 |
+
|
38 |
+
def _resolve_max_length(tokenizer, model, fallback=512):
    """Return a usable padding/truncation length for ``tokenizer``.

    Tokenizers saved without an explicit limit report a very large sentinel
    as ``model_max_length``; padding to that value is not feasible. In that
    case fall back to the model config's ``max_position_embeddings`` when it
    is set, otherwise to ``fallback``.
    """
    max_length = tokenizer.model_max_length
    if max_length is None or max_length > 10 ** 6:
        max_length = getattr(model.config, "max_position_embeddings", None) or fallback
    return int(max_length)


def evaluate(model_path, test_file=None, sequence=None):
    """Classify a single sequence or evaluate the model on a FASTA test set.

    Args:
        model_path: Directory (or hub id) of the fine-tuned sequence
            classification model and its tokenizer.
        test_file: Path to a FASTA file whose headers carry a ``label`` field
            of ``pathogen`` or ``non-pathogen`` (case-insensitive).
        sequence: A single DNA sequence to classify. When given it takes
            precedence over ``test_file``.

    Returns:
        For a single sequence, the predicted class index array; for a test
        file, the metrics dict produced by ``compute_metrics``. Both are also
        printed.

    Raises:
        ValueError: If neither ``sequence`` nor ``test_file`` is given, if the
            test file yields no records, or if it contains labels other than
            ``pathogen`` / ``non-pathogen``.
    """
    if not sequence and not test_file:
        raise ValueError("Provide either `sequence` or `test_file` to evaluate.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): ignore_mismatched_sizes=True lets loading succeed even when
    # checkpoint weights (e.g. the classification head) do not match the model
    # shape, in which case those weights are freshly initialized. Confirm this
    # is intended for a fine-tuned checkpoint.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path, ignore_mismatched_sizes=True
    ).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Make inference mode explicit for both code paths.
    model.eval()
    max_length = _resolve_max_length(tokenizer, model)

    if sequence:
        inputs = encode_sequence(sequence, tokenizer, max_length)
        with torch.no_grad():
            outputs = model(**{k: v.to(device) for k, v in inputs.items()})
        logits = outputs.logits.cpu().numpy()
        prediction = np.argmax(logits, axis=1)
        print("Single Sequence Prediction:", prediction)
        return prediction

    test_df = fasta_to_df(test_file)
    if test_df.empty:
        raise ValueError("No records could be parsed from {!r}.".format(test_file))

    label_map = {
        'non-pathogen': 0,
        'pathogen': 1,
    }
    raw_labels = test_df['label']
    test_df['label'] = raw_labels.str.lower().map(label_map)
    # Labels outside label_map become NaN; fail loudly instead of letting them
    # flow into the metrics as garbage integers.
    unmapped = test_df['label'].isna()
    if unmapped.any():
        unknown = sorted({str(value) for value in raw_labels[unmapped].unique()})
        raise ValueError(
            "Unrecognized label(s) in {!r}: {}. Expected one of {}.".format(
                test_file, unknown, sorted(label_map)
            )
        )

    dataset = Dataset.from_pandas(test_df)
    dataset = dataset.map(
        lambda x: encode_sequence(x['sequence'], tokenizer, max_length),
        batched=True,
    )
    dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
    logits_list, labels_list = [], []

    with torch.no_grad():
        for batch in dataloader:
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
            outputs = model(**inputs)
            logits_list.append(outputs.logits.cpu().numpy())
            labels_list.append(batch['label'].cpu().numpy())

    logits = np.concatenate(logits_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)
    results = compute_metrics(logits, labels)
    print("Evaluation Metrics:", results)
    return results
|
78 |
+
|
79 |
+
if __name__ == "__main__":
    # Command-line entry point: classify one sequence or score a FASTA file.
    parser = argparse.ArgumentParser(
        description="Evaluate a fine-tuned sequence classification model."
    )
    parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model directory")
    parser.add_argument("--test_file", type=str, help="Path to the test fasta file")
    parser.add_argument("--sequence", type=str, help="Single DNA sequence to classify")
    args = parser.parse_args()

    # Without either input there is nothing to evaluate; report it as a usage
    # error instead of failing later inside the FASTA parser.
    if not args.test_file and not args.sequence:
        parser.error("one of --test_file or --sequence is required")

    # When both are supplied, evaluate() handles --sequence and returns
    # before reading --test_file.
    evaluate(args.model_path, args.test_file, args.sequence)
|
output.log
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Evaluation Metrics: {'accuracy': 0.6747420367877972, 'f1_score': 0.6143408107901266, 'mcc': 0.3983982161600102, 'auc_roc': 0.26562839560775575, 'balanced_accuracy': 0.6260618867430012, 'confusion_matrix': [[736, 2171], [4, 3776]]}
|
pathoLM.png
ADDED
![]() |
Git LFS Details
|
utils.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
from Bio.Seq import Seq
|
3 |
+
from Bio.SeqRecord import SeqRecord
|
4 |
+
from Bio import SeqIO
|
5 |
+
|
6 |
+
def stratified_sampling(df, sample_size=5000):
    """Draw the same number of rows from every ``label`` class of ``df``.

    Args:
        df: DataFrame with a ``label`` column.
        sample_size: Maximum number of rows to draw per class. It is capped
            at the size of the smallest class so every class contributes the
            same number of rows.

    Returns:
        A new DataFrame with a fresh 0..n-1 index, containing the sampled
        rows grouped by label (in sorted label order) with all original
        columns, including ``label``. An empty frame is returned when there
        are no labelled rows.
    """
    # Sample each group explicitly instead of using groupby(...).apply(...):
    # apply operating on the grouping column is deprecated in recent pandas,
    # and excluding it would drop the 'label' column from the result.
    groups = [group for _, group in df.groupby('label')]
    if not groups:
        return df.iloc[0:0].reset_index(drop=True)

    per_class = min(sample_size, min(len(group) for group in groups))
    # A fixed seed per group keeps the selection reproducible.
    sampled = [group.sample(n=per_class, random_state=42) for group in groups]
    return pd.concat(sampled).reset_index(drop=True)
|
12 |
+
|
13 |
+
def _parse_fasta_description(description):
    """Split a FASTA header into its id and a dict of ``key:value`` fields.

    The header is expected to look like ``<id> key:value|key:value|...``.
    Raises ``ValueError`` when a ``|``-separated field has no ``:``.
    """
    unique_id, _, remainder = description.partition(' ')
    fields = {}
    for part in remainder.split('|'):
        # Split on the first ':' only so values that themselves contain a
        # colon are kept intact.
        key, separator, value = part.partition(':')
        if not separator:
            raise ValueError(f"field {part!r} is not of the form 'key:value'")
        fields[key.strip()] = value.strip()
    return unique_id, fields


def fasta_to_df(fasta_file):
    """Load a FASTA file into a DataFrame, one row per parsable record.

    Each record header must have the form ``<id> key:value|key:value|...``;
    the ``species``, ``sequence_length`` and ``label`` fields are extracted
    (missing ones become ``None``, or ``0`` for ``sequence_length``). Records
    whose header cannot be parsed are reported on stdout and skipped as a
    whole, so the resulting columns always stay aligned.

    Args:
        fasta_file: Path or handle accepted by ``Bio.SeqIO.parse``.

    Returns:
        DataFrame with the columns ``unique_id``, ``species``,
        ``sequence_length``, ``label`` and ``sequence``.
    """
    columns = ['unique_id', 'species', 'sequence_length', 'label', 'sequence']
    rows = []

    for record in SeqIO.parse(fasta_file, "fasta"):
        try:
            unique_id, fields = _parse_fasta_description(record.description)
            # A non-numeric sequence_length is treated like any other
            # malformed header field.
            sequence_length = int(fields.get('sequence_length', 0))
        except ValueError as e:
            print(f"Error parsing description for record {record.id}: {e}")
            continue

        # Build the row only after every field parsed, so a skipped record
        # contributes nothing to any column.
        rows.append({
            'unique_id': unique_id,
            'species': fields.get('species'),
            'sequence_length': sequence_length,
            'label': fields.get('label'),
            'sequence': str(record.seq),
        })

    return pd.DataFrame(rows, columns=columns)
|