Sajib-006 committed on
Commit 28d5b6a · verified
1 Parent(s): 3a21254
Files changed (7)
  1. .gitattributes +2 -0
  2. LICENSE +21 -0
  3. data/test.fasta +3 -0
  4. eval_model.py +85 -0
  5. output.log +1 -0
  6. pathoLM.png +3 -0
  7. utils.py +45 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ data/test.fasta filter=lfs diff=lfs merge=lfs -text
+ pathoLM.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Sajib Acharjee Dip
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
data/test.fasta ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f8e8010c816aa6fae67ef6a53e35b2907a5b54e610e1dac7e9912dc20526b40
+ size 11990791
eval_model.py ADDED
@@ -0,0 +1,85 @@
+ import argparse
+ import torch
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ from datasets import Dataset
+ import pandas as pd
+ import numpy as np
+ from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef, balanced_accuracy_score, roc_auc_score, confusion_matrix
+ from utils import fasta_to_df
+
+ def compute_metrics(logits, labels):
+     predictions = np.argmax(logits, axis=1)
+     labels = np.array(labels, dtype=int)
+     predictions = np.array(predictions, dtype=int)
+
+     acc = accuracy_score(labels, predictions)
+     f1 = f1_score(labels, predictions, average='weighted')
+     mcc = matthews_corrcoef(labels, predictions)
+     balanced_acc = balanced_accuracy_score(labels, predictions)
+     auc_roc = None
+
+     if len(np.unique(labels)) == 2:
+         probs = np.exp(logits[:, 1]) / np.sum(np.exp(logits), axis=1)
+         auc_roc = roc_auc_score(labels, probs)
+
+     cm = confusion_matrix(labels, predictions)
+     return {
+         'accuracy': acc,
+         'f1_score': f1,
+         'mcc': mcc,
+         'auc_roc': auc_roc,
+         'balanced_accuracy': balanced_acc,
+         'confusion_matrix': cm.tolist()
+     }
+
+ def encode_sequence(sequence, tokenizer, max_length):
+     return tokenizer(sequence, truncation=True, padding='max_length', max_length=max_length, return_tensors='pt')
+
+ def evaluate(model_path, test_file=None, sequence=None):
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model = AutoModelForSequenceClassification.from_pretrained(model_path, ignore_mismatched_sizes=True).to(device)
+     tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+     if sequence:
+         inputs = encode_sequence(sequence, tokenizer, tokenizer.model_max_length)
+         with torch.no_grad():
+             outputs = model(**{k: v.to(device) for k, v in inputs.items()})
+         logits = outputs.logits.cpu().numpy()
+         print("Single Sequence Prediction:", np.argmax(logits, axis=1))
+         return
+
+     test_df = fasta_to_df(test_file)
+     label_map = {
+         'non-pathogen': 0,
+         'pathogen': 1
+     }
+     test_df['label'] = test_df['label'].str.lower().map(label_map)
+     dataset = Dataset.from_pandas(test_df)
+     dataset = dataset.map(lambda x: encode_sequence(x['sequence'], tokenizer, tokenizer.model_max_length), batched=True)
+     dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'label'])
+
+     dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
+     logits_list, labels_list = [], []
+
+     model.eval()
+     with torch.no_grad():
+         for batch in dataloader:
+             inputs = {k: v.to(device) for k, v in batch.items() if k != 'label'}
+             labels = np.array(batch['label'])
+             outputs = model(**inputs)
+             logits_list.append(outputs.logits.cpu().numpy())
+             labels_list.append(labels)
+
+     logits = np.concatenate(logits_list, axis=0)
+     labels = np.concatenate(labels_list, axis=0)
+     results = compute_metrics(logits, labels)
+     print("Evaluation Metrics:", results)
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--model_path", type=str, required=True, help="Path to the fine-tuned model directory")
+     parser.add_argument("--test_file", type=str, help="Path to the test fasta file")
+     parser.add_argument("--sequence", type=str, help="Single DNA sequence to classify")
+     args = parser.parse_args()
+     evaluate(args.model_path, args.test_file, args.sequence)
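A minimal usage sketch for eval_model.py, not part of the commit; the checkpoint directory saved_model and the example DNA string are illustrative placeholders, not names fixed by this repository:

# Usage sketch only; "saved_model" and the sequence below are placeholders.
from eval_model import evaluate

# Classify one sequence (prints the predicted class index).
evaluate("saved_model", sequence="ATGCGTACGTTAGC")

# Or evaluate the bundled test split end to end (prints the metrics dict).
evaluate("saved_model", test_file="data/test.fasta")

The same two modes are exposed on the command line via the --sequence and --test_file arguments.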
output.log ADDED
@@ -0,0 +1 @@
+ Evaluation Metrics: {'accuracy': 0.6747420367877972, 'f1_score': 0.6143408107901266, 'mcc': 0.3983982161600102, 'auc_roc': 0.26562839560775575, 'balanced_accuracy': 0.6260618867430012, 'confusion_matrix': [[736, 2171], [4, 3776]]}
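As a quick cross-check (not part of the commit), the accuracy and balanced accuracy reported here can be recomputed from the confusion matrix in the same line:

import numpy as np

# Confusion matrix copied from output.log: rows are true labels, columns are predictions.
cm = np.array([[736, 2171],   # non-pathogen (0)
               [4, 3776]])    # pathogen (1)

accuracy = cm.trace() / cm.sum()                   # (736 + 3776) / 6687 ≈ 0.6747
recall_per_class = cm.diagonal() / cm.sum(axis=1)  # ≈ [0.2532, 0.9989]
balanced_accuracy = recall_per_class.mean()        # ≈ 0.6261

Both agree with the logged values.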
pathoLM.png ADDED

Git LFS Details

  • SHA256: bcb379dbbcd17e4a8c71af7a36f26b32398bf4af511043f0dffbab5fef36ee0e
  • Pointer size: 131 Bytes
  • Size of remote file: 312 kB
utils.py ADDED
@@ -0,0 +1,45 @@
+ import pandas as pd
+ from Bio.Seq import Seq
+ from Bio.SeqRecord import SeqRecord
+ from Bio import SeqIO
+
+ def stratified_sampling(df, sample_size=5000):
+     label_counts = df['label'].value_counts()
+     min_count = label_counts.min()
+     sample_size = min(sample_size, min_count)
+     sampled_df = df.groupby('label').apply(lambda x: x.sample(n=sample_size, random_state=42)).reset_index(drop=True)
+     return sampled_df
+
+ def fasta_to_df(fasta_file):
+     unique_ids = []
+     species = []
+     sequence_lengths = []
+     labels = []
+     fragment_ids = []
+     sequences = []
+
+     for record in SeqIO.parse(fasta_file, "fasta"):
+         desc_parts = record.description.split(' ', 1)[1] if ' ' in record.description else ''
+         try:
+             desc_parts_dict = {part.split(':')[0].strip(): part.split(':')[1].strip() for part in desc_parts.split('|')}
+         except Exception as e:
+             print(f"Error parsing description for record {record.id}: {e}")
+             continue
+
+         # Append the id only after the header parses, so all columns stay the same length
+         # even when a malformed record is skipped.
+         unique_ids.append(record.description.split(' ')[0])
+         species.append(desc_parts_dict.get('species'))
+         sequence_lengths.append(int(desc_parts_dict.get('sequence_length', 0)))
+         labels.append(desc_parts_dict.get('label'))
+         sequences.append(str(record.seq))
+
+     df = pd.DataFrame({
+         'unique_id': unique_ids,
+         'species': species,
+         'sequence_length': sequence_lengths,
+         'label': labels,
+         'sequence': sequences
+     })
+
+     return df
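For context, fasta_to_df assumes each FASTA header has the form ">ID species:<name>|sequence_length:<n>|label:<pathogen or non-pathogen>". A minimal round-trip sketch follows; the records, species names, and the toy.fasta path are invented for illustration:

# Toy records following the "key:value|key:value" header layout parsed above;
# ids, species, and the toy.fasta filename are made up for this sketch.
from utils import fasta_to_df

lines = [
    ">seq_001 species:Salmonella enterica|sequence_length:12|label:pathogen",
    "ATGCGTACGTTA",
    ">seq_002 species:Bacillus subtilis|sequence_length:12|label:non-pathogen",
    "ATGCCCGGGTTA",
]
with open("toy.fasta", "w") as fh:
    fh.write("\n".join(lines) + "\n")

df = fasta_to_df("toy.fasta")
print(df[["unique_id", "species", "label", "sequence_length"]])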