GeminiFan207 commited on
Commit
e75fa15
·
verified ·
1 Parent(s): 21102ed

Create evaluate.py

Browse files
Files changed (1) hide show
  1. evaluate.py +122 -0
evaluate.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ import json
4
+ import os
5
+ from transformers import AutoModelForCausalLM, PreTrainedTokenizerFast
6
+ from datasets import Dataset, DatasetDict
7
+
8
+ # Paths (adjust as needed)
9
+ MODEL_DIR = "../base_model" # Directory with config.json and .safetensors
10
+ TOKENIZER_JSON = "../tokenizer.json"
11
+ DATASET_DIR = "../datasets/"
12
+
13
+ # Load configuration (assuming it’s your earlier Mistral or generation config)
14
+ with open("../config.json", "r") as f:
15
+ config = json.load(f)
16
+
17
+ def load_model():
18
+ """Load the model and tokenizer with optimizations."""
19
+ device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ print(f"Using device: {device}")
21
+
22
+ try:
23
+ tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_JSON)
24
+ if tokenizer.pad_token is None:
25
+ tokenizer.pad_token = tokenizer.eos_token
26
+
27
+ model = AutoModelForCausalLM.from_pretrained(
28
+ MODEL_DIR,
29
+ torch_dtype=torch.bfloat16, # From your training
30
+ device_map="auto", # Auto-distribute
31
+ low_cpu_mem_usage=True
32
+ ).to(device)
33
+ return model, tokenizer
34
+ except Exception as e:
35
+ print(f"Error loading model/tokenizer: {e}")
36
+ exit(1)
37
+
38
+ def load_custom_dataset(version):
39
+ """Load Eclipse Corpuz dataset based on version."""
40
+ dataset_path = f"{DATASET_DIR}eclipse_corpuz_{version}.json"
41
+ if not os.path.exists(dataset_path):
42
+ print(f"Error: Dataset {dataset_path} not found")
43
+ exit(1)
44
+
45
+ try:
46
+ with open(dataset_path, "r", encoding="utf-8") as f:
47
+ data = json.load(f)
48
+
49
+ # Handle flexible formats
50
+ if isinstance(data, list):
51
+ # If list of dicts with "text" key
52
+ if data and isinstance(data[0], dict) and "text" in data[0]:
53
+ dataset = Dataset.from_list(data)
54
+ # If list of strings
55
+ else:
56
+ dataset = Dataset.from_dict({"text": data})
57
+ else:
58
+ print(f"Error: Unsupported dataset format in {dataset_path}")
59
+ exit(1)
60
+
61
+ return DatasetDict({"test": dataset})
62
+ except Exception as e:
63
+ print(f"Error loading dataset: {e}")
64
+ exit(1)
65
+
66
+ def evaluate(model, tokenizer, dataset, batch_size=8):
67
+ """Evaluate model on Eclipse Corpuz dataset with batching."""
68
+ dataset = dataset["test"]
69
+ model.eval()
70
+ losses = []
71
+ total_tokens = 0
72
+ correct_tokens = 0
73
+
74
+ # Batch processing
75
+ for i in range(0, min(len(dataset), 100), batch_size): # Limit to 100 samples
76
+ batch = dataset[i:i + batch_size]
77
+ inputs = tokenizer(
78
+ batch["text"],
79
+ return_tensors="pt",
80
+ padding=True,
81
+ truncation=True,
82
+ max_length=config.get("max_length", 512) # From config or default
83
+ ).to(model.device)
84
+
85
+ labels = inputs["input_ids"].clone()
86
+
87
+ with torch.no_grad():
88
+ outputs = model(**inputs, labels=labels)
89
+ losses.append(outputs.loss.item())
90
+
91
+ # Shift logits/labels for next-token prediction accuracy
92
+ shift_logits = outputs.logits[..., :-1, :].contiguous()
93
+ shift_labels = labels[..., 1:].contiguous()
94
+ predictions = torch.argmax(shift_logits, dim=-1)
95
+
96
+ mask = shift_labels != tokenizer.pad_token_id # Ignore padding
97
+ correct_tokens += (predictions == shift_labels).masked_select(mask).sum().item()
98
+ total_tokens += mask.sum().item()
99
+
100
+ avg_loss = sum(losses) / len(losses) if losses else float("inf")
101
+ perplexity = torch.exp(torch.tensor(avg_loss)).item()
102
+ accuracy = correct_tokens / total_tokens if total_tokens > 0 else 0
103
+
104
+ return {"accuracy": accuracy, "loss": avg_loss, "perplexity": perplexity}
105
+
106
+ if __name__ == "__main__":
107
+ parser = argparse.ArgumentParser(description="Evaluate Charm 15 on Eclipse Corpuz dataset")
108
+ parser.add_argument("--version", type=str, default="1.1", help="Dataset version (e.g., 1.1, 1.2)")
109
+ args = parser.parse_args()
110
+
111
+ model, tokenizer = load_model()
112
+ dataset = load_custom_dataset(args.version)
113
+ results = evaluate(model, tokenizer, dataset, batch_size=4) # Lowered for memory
114
+
115
+ print(f"Evaluation Results (Eclipse Corpuz {args.version}):")
116
+ print(f"Accuracy: {results['accuracy']:.4f}")
117
+ print(f"Loss: {results['loss']:.4f}")
118
+ print(f"Perplexity: {results['perplexity']:.4f}")
119
+
120
+ # Cleanup
121
+ del model
122
+ torch.cuda.empty_cache()