initial commit
- app.py +123 -0
- database.json +60 -0
- fine_tune_nlu.py +141 -0
- generate_response.py +8 -0
- intent_recognition.py +87 -0
- main.py +67 -0
- nlu_dataset.json +124 -0
- requirements.txt +171 -0
- results/checkpoint-100/config.json +57 -0
- results/checkpoint-100/model.safetensors +3 -0
- results/checkpoint-100/rng_state.pth +3 -0
- results/checkpoint-100/scheduler.pt +3 -0
- results/checkpoint-100/trainer_state.json +825 -0
- results/checkpoint-100/training_args.bin +3 -0
- test_NLU.py +87 -0
- whisper_stt.py +37 -0
app.py
ADDED
@@ -0,0 +1,123 @@
import json
import streamlit as st
from whisper_stt import transcribe_audio
from intent_recognition import get_intent_and_amount
from generate_response import generate_voice_response
from test_NLU import get_slots

DATABASE_PATH = "database.json"


def load_database():
    try:
        with open(DATABASE_PATH, "r") as db_file:
            return json.load(db_file)
    except FileNotFoundError:
        return {"requests": []}


def save_to_database(data):
    with open(DATABASE_PATH, "w") as db_file:
        json.dump(data, db_file, indent=4)


def handle_request(audio_file):
    text = transcribe_audio(audio_file)

    intent_data = get_intent_and_amount(text)
    intent = intent_data.get("intent")
    if intent:
        intent = intent.replace("_", " ").title()
    amount_data = intent_data.get("amount_data")
    amount = amount_data.get("amount") if amount_data else None
    currency = amount_data.get("currency") if amount_data else ""
    slots = get_slots(text)
    project_name = slots.get("project_name")
    project_id = slots.get("project_id")
    task_id = slots.get("task_id")
    status = slots.get("status")

    # Ensure mandatory fields are present. The original `while True` loop
    # re-transcribed the same audio forever when a field was missing, which
    # hangs a Streamlit app; warning and returning lets the user upload a
    # new recording instead.
    if not intent or not amount or not project_id:
        generate_voice_response(
            "Mandatory fields are missing. Please provide the required information again."
        )
        st.warning("Mandatory fields missing. Please upload a new recording.")
        return

    st.write("### Extracted Data")
    st.text(f"Extracted Text: {text}")
    st.text(f"Intent: {intent}")
    st.text(f"Project Name: {project_name}")
    st.text(f"Project ID: {project_id}")
    st.text(f"Amount: {amount} {currency}")
    st.text(f"Task ID: {task_id}")
    st.text(f"Status: {status}")

    response = (
        f"You have requested the task: Intent: {intent}, "
        f"Project: {project_name}. Project ID: {project_id}. "
        f"Amount: {amount} {currency}. Task ID: {task_id} and Status: {status}. "
        "Please confirm by typing your response: Yes or No."
    )
    generate_voice_response(response)

    # User confirmation. An empty value means the user has not answered yet;
    # Streamlit reruns the script once they type, so just return and wait.
    user_input = st.text_input("Type 'yes' or 'no':").strip().lower()
    if not user_input:
        return
    if user_input == "yes":
        request_data = {
            "project": project_name,
            "project_id": project_id,
            "amount": amount,
            "Intent": intent,
            "task_id": task_id,
            "status": status,
        }

        # Save to database
        database = load_database()
        database["requests"].append(request_data)
        save_to_database(database)

        generate_voice_response(
            "Thank you for your response. Your request has been confirmed successfully."
        )
        st.success("Request confirmed and saved successfully.")
        st.session_state.reset = True
    elif user_input == "no":
        generate_voice_response(
            "Thank you for your response. You have denied the confirmation request."
        )
        st.warning("Request denied.")
        st.session_state.reset = True
    else:
        generate_voice_response("You have typed an invalid response.")
        st.error("Invalid response. Please type 'yes' or 'no'.")


# Streamlit App
st.title("ERP Voice Request Handling AI System - Demo")
st.write("Upload an audio file and extract information from the request.")

# Upload audio file
audio_file = st.file_uploader("Upload Audio File", type=["wav", "mp3"])

if audio_file:
    st.write("### Processing Audio Input")
    handle_request(audio_file)

# Display database records
st.write("### Saved Requests in Database")
database = load_database()
if database["requests"]:
    st.json(database)
else:
    st.write("No requests found.")
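Note (not part of the original commit): Streamlit re-executes the whole script on every widget interaction, so the code above re-runs Whisper each time the confirmation box changes. A minimal sketch, assuming the same `transcribe_audio` interface, that caches the transcription per upload in `st.session_state` (`get_transcription` is a hypothetical helper name):

def get_transcription(audio_file):
    # Cache keyed on the uploaded file's name and size, so reruns triggered
    # by the confirmation text box do not transcribe the same audio again.
    key = f"transcript_{audio_file.name}_{audio_file.size}"
    if key not in st.session_state:
        st.session_state[key] = transcribe_audio(audio_file)
    return st.session_state[key]

Calling `get_transcription(audio_file)` in place of `transcribe_audio(audio_file)` inside `handle_request` would keep the confirmation flow responsive.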
database.json
ADDED
@@ -0,0 +1,60 @@
{
    "requests": [
        {
            "project": null,
            "project_id": "223",
            "amount": "500",
            "Intent": "Request Money",
            "task_id": null,
            "status": null
        },
        {
            "project": null,
            "project_id": "223",
            "amount": "500",
            "Intent": "Request Money",
            "task_id": null,
            "status": null
        },
        {
            "project": null,
            "project_id": "223",
            "amount": "500",
            "Intent": "Request Money",
            "task_id": null,
            "status": null
        },
        {
            "project": null,
            "project_id": "223",
            "amount": "500",
            "Intent": "Request Money",
            "task_id": null,
            "status": null
        },
        {
            "project": null,
            "project_id": "223",
            "amount": "500",
            "Intent": "Request Money",
            "task_id": null,
            "status": null
        },
        {
            "project": null,
            "project_id": "223",
            "amount": "500",
            "Intent": "Request Money",
            "task_id": null,
            "status": null
        },
        {
            "project": null,
            "project_id": "223",
            "amount": "500",
            "Intent": "Request Money",
            "task_id": null,
            "status": null
        }
    ]
}
fine_tune_nlu.py
ADDED
@@ -0,0 +1,141 @@
from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments, DataCollatorForTokenClassification
from datasets import DatasetDict, Dataset
import json


def preprocess_data1(json_path, tokenizer):
    with open(json_path, "r") as f:
        data = json.load(f)["data"]

    tokenized_data = {"input_ids": [], "attention_mask": [], "labels": []}
    slot_label_map = {"O": 0}
    label_id = 1

    for intent_data in data:
        for utterance in intent_data["utterances"]:
            text = utterance["text"]
            encoding = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=128,
                return_offsets_mapping=True
            )
            tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"])

            # Create slot labels for the tokens
            slot_labels = ["O"] * len(tokens)
            for slot, value in utterance["slots"].items():
                if value != "not specified":  # Skip unspecified slots
                    slot_tokens = tokenizer.tokenize(value)
                    for i in range(len(tokens) - len(slot_tokens) + 1):
                        if tokens[i:i + len(slot_tokens)] == slot_tokens:
                            slot_labels[i] = f"B-{slot}"
                            for j in range(1, len(slot_tokens)):
                                slot_labels[i + j] = f"I-{slot}"

            # Map slot labels to IDs
            for label in slot_labels:
                if label not in slot_label_map:
                    slot_label_map[label] = label_id
                    label_id += 1

            label_ids = [slot_label_map[label] for label in slot_labels]

            tokenized_data["input_ids"].append(encoding["input_ids"])
            tokenized_data["attention_mask"].append(encoding["attention_mask"])
            tokenized_data["labels"].append(label_ids)

    print("Slot Label Map:", slot_label_map)

    dataset = Dataset.from_dict(tokenized_data)
    return DatasetDict({"train": dataset, "validation": dataset}), slot_label_map


# Updated training preprocessing to handle multi-token amounts
def preprocess_data(json_path, tokenizer):
    with open(json_path, "r") as f:
        data = json.load(f)["data"]

    tokenized_data = {"input_ids": [], "attention_mask": [], "labels": []}
    slot_label_map = {"O": 0}
    label_id = 1  # was missing in the original; without it the loop below raises NameError

    for intent_data in data:
        for utterance in intent_data["utterances"]:
            text = utterance["text"]
            encoding = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=128,
                return_offsets_mapping=True
            )
            tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"])

            slot_labels = ["O"] * len(tokens)
            for slot, value in utterance["slots"].items():
                if value != "not specified":
                    slot_tokens = tokenizer.tokenize(value)
                    for i in range(len(tokens) - len(slot_tokens) + 1):
                        if tokens[i:i + len(slot_tokens)] == slot_tokens:
                            slot_labels[i] = f"B-{slot}"
                            for j in range(1, len(slot_tokens)):
                                slot_labels[i + j] = f"I-{slot}"

            # Map slot labels to IDs
            for label in slot_labels:
                if label not in slot_label_map:
                    slot_label_map[label] = label_id
                    label_id += 1

            label_ids = [slot_label_map[label] for label in slot_labels]

            tokenized_data["input_ids"].append(encoding["input_ids"])
            tokenized_data["attention_mask"].append(encoding["attention_mask"])
            tokenized_data["labels"].append(label_ids)

    dataset = Dataset.from_dict(tokenized_data)
    return DatasetDict({"train": dataset, "validation": dataset}), slot_label_map


tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")

json_path = "nlu_dataset.json"
dataset, slot_label_map = preprocess_data(json_path, tokenizer)

model = BertForTokenClassification.from_pretrained(
    "bert-base-multilingual-cased",
    num_labels=len(slot_label_map)
)

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=100,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    logging_dir="./logs",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    data_collator=data_collator
)

trainer.train()
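One detail worth flagging (not in the original script): the preprocessing labels the [CLS], [SEP], and [PAD] positions as "O", so the token-classification loss is also computed over special and padding tokens. A common refinement, sketched here under the assumption that the same fast tokenizer and `return_offsets_mapping=True` encoding are used, masks those positions with -100 so the cross-entropy loss ignores them:

def mask_special_tokens(encoding, label_ids, ignore_index=-100):
    # Fast tokenizers report a (0, 0) character span for special and padding
    # tokens; labeling them -100 excludes them from the loss computation.
    return [
        ignore_index if offset == (0, 0) else label
        for offset, label in zip(encoding["offset_mapping"], label_ids)
    ]

Applying `mask_special_tokens(encoding, label_ids)` before appending to `tokenized_data["labels"]` would be enough to enable it.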
generate_response.py
ADDED
@@ -0,0 +1,8 @@
import pyttsx3

def generate_voice_response(text: str, lang="en"):
    engine = pyttsx3.init()
    engine.setProperty('rate', 150)  # Speed
    engine.setProperty('volume', 1)  # (0.0 to 1.0)
    engine.say(text)
    engine.runAndWait()
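pyttsx3 depends on a platform speech engine (the pinned `pypiwin32`/`comtypes` in requirements.txt point at SAPI5 on Windows) and an audio output device, so `pyttsx3.init()` typically raises on a headless server. A defensive sketch, not in the original, that degrades to printing the text instead of crashing the request flow (`generate_voice_response_safe` is a hypothetical name):

def generate_voice_response_safe(text: str) -> None:
    try:
        engine = pyttsx3.init()
        engine.setProperty('rate', 150)
        engine.say(text)
        engine.runAndWait()
    except Exception as exc:
        # Driver-specific failures (no SAPI5/espeak/nsss, no audio device)
        # should not take down the whole app.
        print(f"[TTS unavailable] {text} ({exc})")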
intent_recognition.py
ADDED
@@ -0,0 +1,87 @@
import re
from sentence_transformers import SentenceTransformer, util
from typing import Any, Dict, Optional


model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

# Define the intent rules
dataset = {
    "data": [
        {
            "intent": "request_money",
            "utterances": [
                {"text": "I need to request money for project 223 to buy some tools, the amount I need is 500 riyals"},
                {"text": "Please add a money request for the project Abha University for 300 riyals"},
                {"text": "I need 1000 riyals for project 445 to purchase some equipment"},
                {"text": "Can you initiate a money request for project 678 with an amount of 250 riyals for team activities?"},
                {"text": "Requesting 800 riyals for the project Green Energy for office supplies"}
            ]
        },
        {
            "intent": "submit_task",
            "utterances": [
                {"text": "I have completed the task 1025, please mark it as done"},
                {"text": "Mark task 3054 as finished in the system"},
                {"text": "Task 8899 has been completed, update its status"},
                {"text": "Please mark task 1122 as done, I just finished it"},
                {"text": "Set the status of task 4500 to finished"}
            ]
        },
        {
            "intent": "get_project_status",
            "utterances": [
                {"text": "Can you tell me the status of project 223?"},
                {"text": "What is the current progress on project Abha University?"},
                {"text": "I need an update on project 445. What is its status?"},
                {"text": "Could you check and let me know the status of the Smart City project?"},
                {"text": "What’s the progress on the renewable energy project?"}
            ]
        }
    ]
}

def extract_amount_with_context(text: str) -> Optional[Dict[str, Any]]:
    """Extract the amount, the currency term, and surrounding context using regex."""

    # Capture the amount, the currency word, and up to two following words.
    # The original pattern had no whitespace between the two context groups,
    # so it split a single word in two instead of matching two words.
    match = re.search(
        r'(\d+)\s*(riyals?|reels?|rils?|reel?|dollars?|money|amount|usd|euro|pounds?)\s*(\w{1,20})?\s*(\w{1,20})?',
        text.lower()
    )

    if match:
        # Extract the amount and the currency type
        amount = match.group(1)
        currency = match.group(2)
        # Join only the context words that matched; the original f-string
        # produced the literal text "None" when group 4 was absent.
        context_words = [w for w in (match.group(3), match.group(4)) if w]
        additional_info = " ".join(context_words) or None
        return {"amount": amount, "currency": currency, "context": additional_info}
    return None

def get_intent_and_amount(text: str) -> Dict[str, Any]:
    """
    Extract the intent and the amount (if present) from a given text using a similarity model.
    """
    best_match = None
    best_score = 0
    intent = "unknown"
    amount_data = extract_amount_with_context(text)

    # Detect the intent by comparing against every example utterance
    for intent_data in dataset["data"]:
        for utterance in intent_data["utterances"]:
            # Compute similarity
            similarity_score = util.pytorch_cos_sim(
                model.encode(text, convert_to_tensor=True),
                model.encode(utterance["text"], convert_to_tensor=True)
            ).item()

            if similarity_score > best_score:
                best_score = similarity_score
                best_match = utterance
                intent = intent_data["intent"]

    return {"intent": intent, "amount_data": amount_data, "score": best_score}

# Example test; guarded so importing this module (as app.py and main.py do)
# does not run the encoder and print at import time.
if __name__ == "__main__":
    user_text = "Hey, I need to request money for a project name Abha University and id is 123 and the amount is 500 riyals"
    result = get_intent_and_amount(user_text)
    print(result)
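Because `get_intent_and_amount` re-encodes every example utterance (and the query, repeatedly) on each call, latency grows with the example set. A sketch, not in the original, that encodes the examples once at module load and scores a query in a single cosine-similarity call (the names `_example_texts`, `_example_intents`, and `get_intent_fast` are hypothetical):

# Pre-encode all example utterances once; reuse the matrix for every query.
_example_texts = [u["text"] for d in dataset["data"] for u in d["utterances"]]
_example_intents = [d["intent"] for d in dataset["data"] for u in d["utterances"]]
_example_embeddings = model.encode(_example_texts, convert_to_tensor=True)

def get_intent_fast(text: str) -> Dict[str, Any]:
    query = model.encode(text, convert_to_tensor=True)
    scores = util.pytorch_cos_sim(query, _example_embeddings)[0]
    best = int(scores.argmax())
    return {"intent": _example_intents[best], "score": float(scores[best])}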
main.py
ADDED
@@ -0,0 +1,67 @@
import json
from whisper_stt import transcribe_audio_raw
from intent_recognition import get_intent_and_amount
from generate_response import generate_voice_response
from test_NLU import get_slots

DATABASE_PATH = "database.json"

def load_database():
    with open(DATABASE_PATH, "r") as db_file:
        return json.load(db_file)

def save_to_database(data):
    with open(DATABASE_PATH, "w") as db_file:
        json.dump(data, db_file, indent=4)

def handle_request(audio_file):
    text = transcribe_audio_raw(audio_file)

    intent_data = get_intent_and_amount(text)
    intent = intent_data.get('intent')
    if intent is not None:
        intent = intent.replace("_", " ").title()
    amount_data = intent_data.get('amount_data')
    # Guard against a missing amount: calling .get() on None would raise.
    amount = amount_data.get('amount') if amount_data else None
    currency = amount_data.get('currency') if amount_data else ''
    slots = get_slots(text)
    slots['amount'] = f"{amount} {currency}".strip() if amount else None
    if intent is not None:
        response = (
            f"You have requested the task: Intent: {intent}, "
            f"Project: {slots.get('project_name')}. Project ID: {slots.get('project_id')}. "
            f"Amount: {slots.get('amount')}. Task ID: {slots.get('task_id')} "
            f"and Status: {slots.get('status')}. "
            "Please confirm by typing your response: Yes or No: "
        )
        generate_voice_response(response)
        user_input = input("Please type your response: Yes or No: ")

        if user_input.lower() == "yes":
            # Prepare the data to save
            request_data = {
                "project": slots.get("project_name"),
                "project_id": slots.get("project_id"),
                "amount": amount,
                "Intent": intent,
                "task_id": slots.get("task_id"),
                "status": slots.get("status"),
            }

            database = load_database()
            database["requests"].append(request_data)
            save_to_database(database)
            generate_voice_response("Thank you for your response. Your request has been confirmed successfully.")

        elif user_input.lower() == "no":
            generate_voice_response("Thank you for your response. You have denied the confirmation request.")
        else:
            generate_voice_response("You have typed an invalid response.")
    else:
        response = "Sorry, I did not understand your request."
        generate_voice_response(response)
    return response


if __name__ == "__main__":
    user_audio = "input_audio.wav"
    # audio_file = open(user_audio, "rb")
    print(handle_request(user_audio))
nlu_dataset.json
ADDED
@@ -0,0 +1,124 @@
{
    "data": [
        {
            "intent": "request_money",
            "utterances": [
                {
                    "text": "I need to request money for project 223 to buy some tools, the amount I need is 500 riyals",
                    "slots": {
                        "project_id": "223",
                        "reason": "buy some tools",
                        "amount": "500 riyals"
                    }
                },
                {
                    "text": "Please add a money request for the project Abha University for 300 riyals",
                    "slots": {
                        "project_name": "Abha University",
                        "reason": "not specified",
                        "amount": "300 riyals"
                    }
                },
                {
                    "text": "I need 1000 riyals for project 445 to purchase some equipment",
                    "slots": {
                        "project_id": "445",
                        "reason": "purchase some equipment",
                        "amount": "1000 riyals"
                    }
                },
                {
                    "text": "Can you initiate a money request for project 678 with an amount of 250 riyals for team activities?",
                    "slots": {
                        "project_id": "678",
                        "reason": "team activities",
                        "amount": "250 riyals"
                    }
                },
                {
                    "text": "Requesting 800 riyals for the project Green Energy for office supplies",
                    "slots": {
                        "project_name": "Green Energy",
                        "reason": "office supplies",
                        "amount": "800 riyals"
                    }
                }
            ]
        },
        {
            "intent": "submit_task",
            "utterances": [
                {
                    "text": "I have completed the task 1025, please mark it as done",
                    "slots": {
                        "task_id": "1025",
                        "status": "completed"
                    }
                },
                {
                    "text": "Mark task 3054 as finished in the system",
                    "slots": {
                        "task_id": "3054",
                        "status": "finished"
                    }
                },
                {
                    "text": "Task 8899 has been completed, update its status",
                    "slots": {
                        "task_id": "8899",
                        "status": "completed"
                    }
                },
                {
                    "text": "Please mark task 1122 as done, I just finished it",
                    "slots": {
                        "task_id": "1122",
                        "status": "done"
                    }
                },
                {
                    "text": "Set the status of task 4500 to finished",
                    "slots": {
                        "task_id": "4500",
                        "status": "finished"
                    }
                }
            ]
        },
        {
            "intent": "get_project_status",
            "utterances": [
                {
                    "text": "Can you tell me the status of project 223?",
                    "slots": {
                        "project_id": "223"
                    }
                },
                {
                    "text": "What is the current progress on project Abha University?",
                    "slots": {
                        "project_name": "Abha University"
                    }
                },
                {
                    "text": "I need an update on project 445. What is its status?",
                    "slots": {
                        "project_id": "445"
                    }
                },
                {
                    "text": "Could you check and let me know the status of the Smart City project?",
                    "slots": {
                        "project_name": "Smart City"
                    }
                },
                {
                    "text": "What’s the progress on the renewable energy project?",
                    "slots": {
                        "project_name": "renewable energy"
                    }
                }
            ]
        }
    ]
}
requirements.txt
ADDED
@@ -0,0 +1,171 @@
absl-py==2.1.0
accelerate==1.2.1
aiohappyeyeballs==2.4.4
aiohttp==3.11.11
aiosignal==1.3.2
altair==5.5.0
annotated-types==0.7.0
anyascii==0.3.2
attrs==24.3.0
audioread==3.0.1
babel==2.16.0
bangla==0.0.2
blinker==1.9.0
blis==1.1.0
bnnumerizer==0.0.2
bnunicodenormalizer==0.1.7
cachetools==5.5.0
catalogue==2.0.10
certifi==2024.12.14
cffi==1.17.1
charset-normalizer==3.4.1
click==8.1.8
cloudpathlib==0.20.0
colorama==0.4.6
comtypes==1.4.8
confection==0.1.5
contourpy==1.3.1
coqpit==0.0.17
cycler==0.12.1
cymem==2.0.10
Cython==3.0.11
datasets==3.2.0
dateparser==1.1.8
decorator==5.1.1
dill==0.3.8
docopt==0.6.2
einops==0.8.0
encodec==0.1.1
filelock==3.16.1
Flask==3.1.0
fonttools==4.55.3
frozenlist==1.5.0
fsspec==2024.9.0
g2pkk==0.1.2
gitdb==4.0.11
GitPython==3.1.43
grpcio==1.68.1
gruut==2.2.3
gruut-ipa==0.13.0
gruut_lang_de==2.0.1
gruut_lang_en==2.0.1
gruut_lang_es==2.0.1
gruut_lang_fr==2.0.2
hangul-romanize==0.1.0
huggingface-hub==0.27.0
idna==3.10
inflect==7.4.0
itsdangerous==2.2.0
jamo==0.4.1
jieba==0.42.1
Jinja2==3.1.5
joblib==1.4.2
jsonlines==1.2.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
kiwisolver==1.4.8
langcodes==3.5.0
language_data==1.3.0
lazy_loader==0.4
librosa==0.10.2.post1
llvmlite==0.43.0
marisa-trie==1.2.1
Markdown==3.7
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.10.0
mdurl==0.1.2
more-itertools==10.5.0
mpmath==1.3.0
msgpack==1.1.0
multidict==6.1.0
multiprocess==0.70.16
murmurhash==1.0.11
narwhals==1.19.1
networkx==2.8.8
nltk==3.9.1
num2words==0.5.14
numba==0.60.0
numpy==1.26.4
openai-whisper==20240930
packaging==24.2
pandas==1.5.3
pillow==11.0.0
platformdirs==4.3.6
pooch==1.8.2
preshed==3.0.9
propcache==0.2.1
protobuf==5.29.2
psutil==6.1.1
pyarrow==18.1.0
pycparser==2.22
pydantic==2.10.4
pydantic_core==2.27.2
pydeck==0.9.1
pydub==0.25.1
Pygments==2.18.0
pynndescent==0.5.13
pyparsing==3.2.0
pypinyin==0.53.0
pypiwin32==223
pysbd==0.3.4
python-crfsuite==0.9.11
python-dateutil==2.9.0.post0
pyttsx3==2.98
pytz==2024.2
pywin32==308
PyYAML==6.0.2
referencing==0.35.1
regex==2024.11.6
requests==2.32.3
rich==13.9.4
rpds-py==0.22.3
safetensors==0.4.5
scikit-learn==1.6.0
scipy==1.14.1
sentence-transformers==3.3.1
shellingham==1.5.4
six==1.17.0
smart-open==7.1.0
smmap==5.0.1
soundfile==0.12.1
soxr==0.5.0.post1
spacy==3.8.3
spacy-legacy==3.0.12
spacy-loggers==1.0.5
srsly==2.5.0
streamlit==1.41.1
SudachiDict-core==20241021
SudachiPy==0.6.9
sympy==1.13.1
tenacity==9.0.0
tensorboard==2.18.0
tensorboard-data-server==0.7.2
thinc==8.3.3
threadpoolctl==3.5.0
tiktoken==0.8.0
tokenizers==0.21.0
toml==0.10.2
torch==2.5.1
torchaudio==2.5.1
tornado==6.4.2
tqdm==4.67.1
trainer==0.0.36
transformers==4.47.1
TTS==0.22.0
typeguard==4.4.1
typer==0.15.1
typing_extensions==4.12.2
tzdata==2024.2
tzlocal==5.2
umap-learn==0.5.7
Unidecode==1.3.8
urllib3==2.3.0
wasabi==1.1.3
watchdog==6.0.0
weasel==0.4.1
Werkzeug==3.1.3
whisper==1.1.10
wrapt==1.17.0
xxhash==3.5.0
yarl==1.18.3
results/checkpoint-100/config.json
ADDED
@@ -0,0 +1,57 @@
{
  "_name_or_path": "bert-base-multilingual-cased",
  "architectures": [
    "BertForTokenClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8",
    "9": "LABEL_9",
    "10": "LABEL_10"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_10": 10,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5,
    "LABEL_6": 6,
    "LABEL_7": 7,
    "LABEL_8": 8,
    "LABEL_9": 9
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.47.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 119547
}
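The generic LABEL_0 … LABEL_10 names appear because the training script never passes `id2label` to the model. A sketch, not in the original, of writing the real slot names into the config before saving, assuming the `model` and `slot_label_map` from fine_tune_nlu.py are in scope (the output path is hypothetical):

# Store human-readable slot names in the checkpoint config before saving.
model.config.id2label = {i: label for label, i in slot_label_map.items()}
model.config.label2id = dict(slot_label_map)
model.save_pretrained("./results/labeled-checkpoint")  # hypothetical path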
results/checkpoint-100/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f3fc34a5672fc68bec86ba7ac93ec8e1f3c5c5c524673048361d6feef7221237
size 709108588
results/checkpoint-100/rng_state.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f5db8afdfb30a4f607648049882a8525861563526751d88d52cf6941d75e21b
size 13990
results/checkpoint-100/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7b86863d3b3ac3245ff3a6b271bc0e11992d0ed1afe7e2015bafd85562b8ce01
size 1064
results/checkpoint-100/trainer_state.json
ADDED
@@ -0,0 +1,825 @@
1 |
+
{
|
2 |
+
"best_metric": null,
|
3 |
+
"best_model_checkpoint": null,
|
4 |
+
"epoch": 100.0,
|
5 |
+
"eval_steps": 500,
|
6 |
+
"global_step": 100,
|
7 |
+
"is_hyper_param_search": false,
|
8 |
+
"is_local_process_zero": true,
|
9 |
+
"is_world_process_zero": true,
|
10 |
+
"log_history": [
|
11 |
+
{
|
12 |
+
"epoch": 1.0,
|
13 |
+
"eval_loss": 0.7868185043334961,
|
14 |
+
"eval_runtime": 1.2903,
|
15 |
+
"eval_samples_per_second": 11.626,
|
16 |
+
"eval_steps_per_second": 0.775,
|
17 |
+
"step": 1
|
18 |
+
},
|
19 |
+
{
|
20 |
+
"epoch": 2.0,
|
21 |
+
"eval_loss": 0.28457802534103394,
|
22 |
+
"eval_runtime": 1.5464,
|
23 |
+
"eval_samples_per_second": 9.7,
|
24 |
+
"eval_steps_per_second": 0.647,
|
25 |
+
"step": 2
|
26 |
+
},
|
27 |
+
{
|
28 |
+
"epoch": 3.0,
|
29 |
+
"eval_loss": 0.21589069068431854,
|
30 |
+
"eval_runtime": 1.4746,
|
31 |
+
"eval_samples_per_second": 10.172,
|
32 |
+
"eval_steps_per_second": 0.678,
|
33 |
+
"step": 3
|
34 |
+
},
|
35 |
+
{
|
36 |
+
"epoch": 4.0,
|
37 |
+
"eval_loss": 0.2054772973060608,
|
38 |
+
"eval_runtime": 1.6254,
|
39 |
+
"eval_samples_per_second": 9.229,
|
40 |
+
"eval_steps_per_second": 0.615,
|
41 |
+
"step": 4
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"epoch": 5.0,
|
45 |
+
"eval_loss": 0.19293639063835144,
|
46 |
+
"eval_runtime": 1.6474,
|
47 |
+
"eval_samples_per_second": 9.105,
|
48 |
+
"eval_steps_per_second": 0.607,
|
49 |
+
"step": 5
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"epoch": 6.0,
|
53 |
+
"eval_loss": 0.17430831491947174,
|
54 |
+
"eval_runtime": 1.8466,
|
55 |
+
"eval_samples_per_second": 8.123,
|
56 |
+
"eval_steps_per_second": 0.542,
|
57 |
+
"step": 6
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"epoch": 7.0,
|
61 |
+
"eval_loss": 0.1566132754087448,
|
62 |
+
"eval_runtime": 1.6633,
|
63 |
+
"eval_samples_per_second": 9.018,
|
64 |
+
"eval_steps_per_second": 0.601,
|
65 |
+
"step": 7
|
66 |
+
},
|
67 |
+
{
|
68 |
+
"epoch": 8.0,
|
69 |
+
"eval_loss": 0.15266196429729462,
|
70 |
+
"eval_runtime": 1.8344,
|
71 |
+
"eval_samples_per_second": 8.177,
|
72 |
+
"eval_steps_per_second": 0.545,
|
73 |
+
"step": 8
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"epoch": 9.0,
|
77 |
+
"eval_loss": 0.13321934640407562,
|
78 |
+
"eval_runtime": 1.7114,
|
79 |
+
"eval_samples_per_second": 8.765,
|
80 |
+
"eval_steps_per_second": 0.584,
|
81 |
+
"step": 9
|
82 |
+
},
|
83 |
+
{
|
84 |
+
"epoch": 10.0,
|
85 |
+
"eval_loss": 0.12868323922157288,
|
86 |
+
"eval_runtime": 1.6882,
|
87 |
+
"eval_samples_per_second": 8.885,
|
88 |
+
"eval_steps_per_second": 0.592,
|
89 |
+
"step": 10
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"epoch": 11.0,
|
93 |
+
"eval_loss": 0.1147986426949501,
|
94 |
+
"eval_runtime": 2.0407,
|
95 |
+
"eval_samples_per_second": 7.35,
|
96 |
+
"eval_steps_per_second": 0.49,
|
97 |
+
"step": 11
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"epoch": 12.0,
|
101 |
+
"eval_loss": 0.11238791793584824,
|
102 |
+
"eval_runtime": 1.8299,
|
103 |
+
"eval_samples_per_second": 8.197,
|
104 |
+
"eval_steps_per_second": 0.546,
|
105 |
+
"step": 12
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"epoch": 13.0,
|
109 |
+
"eval_loss": 0.09630943089723587,
|
110 |
+
"eval_runtime": 1.9623,
|
111 |
+
"eval_samples_per_second": 7.644,
|
112 |
+
"eval_steps_per_second": 0.51,
|
113 |
+
"step": 13
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"epoch": 14.0,
|
117 |
+
"eval_loss": 0.0899113267660141,
|
118 |
+
"eval_runtime": 2.0807,
|
119 |
+
"eval_samples_per_second": 7.209,
|
120 |
+
"eval_steps_per_second": 0.481,
|
121 |
+
"step": 14
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"epoch": 15.0,
|
125 |
+
"eval_loss": 0.0796389952301979,
|
126 |
+
"eval_runtime": 2.043,
|
127 |
+
"eval_samples_per_second": 7.342,
|
128 |
+
"eval_steps_per_second": 0.489,
|
129 |
+
"step": 15
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"epoch": 16.0,
|
133 |
+
"eval_loss": 0.07456444948911667,
|
134 |
+
"eval_runtime": 2.0972,
|
135 |
+
"eval_samples_per_second": 7.152,
|
136 |
+
"eval_steps_per_second": 0.477,
|
137 |
+
"step": 16
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"epoch": 17.0,
|
141 |
+
"eval_loss": 0.0698675587773323,
|
142 |
+
"eval_runtime": 2.3192,
|
143 |
+
"eval_samples_per_second": 6.468,
|
144 |
+
"eval_steps_per_second": 0.431,
|
145 |
+
"step": 17
|
146 |
+
},
|
147 |
+
{
|
148 |
+
"epoch": 18.0,
|
149 |
+
"eval_loss": 0.06313543021678925,
|
150 |
+
"eval_runtime": 2.0522,
|
151 |
+
"eval_samples_per_second": 7.309,
|
152 |
+
"eval_steps_per_second": 0.487,
|
153 |
+
"step": 18
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"epoch": 19.0,
|
157 |
+
"eval_loss": 0.05887909233570099,
|
158 |
+
"eval_runtime": 1.9907,
|
159 |
+
"eval_samples_per_second": 7.535,
|
160 |
+
"eval_steps_per_second": 0.502,
|
161 |
+
"step": 19
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"epoch": 20.0,
|
165 |
+
"eval_loss": 0.0551617294549942,
|
166 |
+
"eval_runtime": 2.0302,
|
167 |
+
"eval_samples_per_second": 7.388,
|
168 |
+
"eval_steps_per_second": 0.493,
|
169 |
+
"step": 20
|
170 |
+
},
|
171 |
+
{
|
172 |
+
"epoch": 21.0,
|
173 |
+
"eval_loss": 0.0511007234454155,
|
174 |
+
"eval_runtime": 2.2652,
|
175 |
+
"eval_samples_per_second": 6.622,
|
176 |
+
"eval_steps_per_second": 0.441,
|
177 |
+
"step": 21
|
178 |
+
},
|
179 |
+
{
|
180 |
+
"epoch": 22.0,
|
181 |
+
"eval_loss": 0.04705721512436867,
|
182 |
+
"eval_runtime": 2.3722,
|
183 |
+
"eval_samples_per_second": 6.323,
|
184 |
+
"eval_steps_per_second": 0.422,
|
185 |
+
"step": 22
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"epoch": 23.0,
|
189 |
+
"eval_loss": 0.0431244932115078,
|
190 |
+
"eval_runtime": 2.2034,
|
191 |
+
"eval_samples_per_second": 6.808,
|
192 |
+
"eval_steps_per_second": 0.454,
|
193 |
+
"step": 23
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"epoch": 24.0,
|
197 |
+
"eval_loss": 0.039149921387434006,
|
198 |
+
"eval_runtime": 2.1518,
|
199 |
+
"eval_samples_per_second": 6.971,
|
200 |
+
"eval_steps_per_second": 0.465,
|
201 |
+
"step": 24
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"epoch": 25.0,
|
205 |
+
"eval_loss": 0.03541847690939903,
|
206 |
+
"eval_runtime": 2.111,
|
207 |
+
"eval_samples_per_second": 7.106,
|
208 |
+
"eval_steps_per_second": 0.474,
|
209 |
+
"step": 25
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"epoch": 26.0,
|
213 |
+
"eval_loss": 0.03274580463767052,
|
214 |
+
"eval_runtime": 2.0862,
|
215 |
+
"eval_samples_per_second": 7.19,
|
216 |
+
"eval_steps_per_second": 0.479,
|
217 |
+
"step": 26
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"epoch": 27.0,
|
221 |
+
"eval_loss": 0.030214540660381317,
|
222 |
+
"eval_runtime": 2.1022,
|
223 |
+
"eval_samples_per_second": 7.135,
|
224 |
+
"eval_steps_per_second": 0.476,
|
225 |
+
"step": 27
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"epoch": 28.0,
|
229 |
+
"eval_loss": 0.02784493751823902,
|
230 |
+
"eval_runtime": 2.1074,
|
231 |
+
"eval_samples_per_second": 7.118,
|
232 |
+
"eval_steps_per_second": 0.475,
|
233 |
+
"step": 28
|
234 |
+
},
|
235 |
+
{
|
236 |
+
"epoch": 29.0,
|
237 |
+
"eval_loss": 0.02551179751753807,
|
238 |
+
"eval_runtime": 2.2566,
|
239 |
+
"eval_samples_per_second": 6.647,
|
240 |
+
"eval_steps_per_second": 0.443,
|
241 |
+
"step": 29
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"epoch": 30.0,
|
245 |
+
"eval_loss": 0.02329176291823387,
|
246 |
+
"eval_runtime": 2.3476,
|
247 |
+
"eval_samples_per_second": 6.39,
|
248 |
+
"eval_steps_per_second": 0.426,
|
249 |
+
"step": 30
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"epoch": 31.0,
|
253 |
+
"eval_loss": 0.02115248702466488,
|
254 |
+
"eval_runtime": 2.3423,
|
255 |
+
"eval_samples_per_second": 6.404,
|
256 |
+
"eval_steps_per_second": 0.427,
|
257 |
+
"step": 31
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"epoch": 32.0,
|
261 |
+
"eval_loss": 0.01917836256325245,
|
262 |
+
"eval_runtime": 2.44,
|
263 |
+
"eval_samples_per_second": 6.148,
|
264 |
+
"eval_steps_per_second": 0.41,
|
265 |
+
"step": 32
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"epoch": 33.0,
|
269 |
+
"eval_loss": 0.017496634274721146,
|
270 |
+
"eval_runtime": 2.3798,
|
271 |
+
"eval_samples_per_second": 6.303,
|
272 |
+
"eval_steps_per_second": 0.42,
|
273 |
+
"step": 33
|
274 |
+
},
|
275 |
+
{
|
276 |
+
"epoch": 34.0,
|
277 |
+
"eval_loss": 0.016098586842417717,
|
278 |
+
"eval_runtime": 2.8157,
|
279 |
+
"eval_samples_per_second": 5.327,
|
280 |
+
"eval_steps_per_second": 0.355,
|
281 |
+
"step": 34
|
282 |
+
},
|
283 |
+
{
|
284 |
+
"epoch": 35.0,
|
285 |
+
"eval_loss": 0.014923782087862492,
|
286 |
+
"eval_runtime": 2.455,
|
287 |
+
"eval_samples_per_second": 6.11,
|
288 |
+
"eval_steps_per_second": 0.407,
|
289 |
+
"step": 35
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"epoch": 36.0,
|
293 |
+
"eval_loss": 0.013880550861358643,
|
294 |
+
"eval_runtime": 2.6382,
|
295 |
+
"eval_samples_per_second": 5.686,
|
296 |
+
"eval_steps_per_second": 0.379,
|
297 |
+
"step": 36
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"epoch": 37.0,
|
301 |
+
"eval_loss": 0.012886795215308666,
|
302 |
+
"eval_runtime": 2.8315,
|
303 |
+
"eval_samples_per_second": 5.298,
|
304 |
+
"eval_steps_per_second": 0.353,
|
305 |
+
"step": 37
|
306 |
+
},
|
307 |
+
{
|
308 |
+
"epoch": 38.0,
|
309 |
+
"eval_loss": 0.012055573984980583,
|
310 |
+
"eval_runtime": 2.8786,
|
311 |
+
"eval_samples_per_second": 5.211,
|
312 |
+
"eval_steps_per_second": 0.347,
|
313 |
+
"step": 38
|
314 |
+
},
|
315 |
+
{
|
316 |
+
"epoch": 39.0,
|
317 |
+
"eval_loss": 0.011289969086647034,
|
318 |
+
"eval_runtime": 3.6692,
|
319 |
+
"eval_samples_per_second": 4.088,
|
320 |
+
"eval_steps_per_second": 0.273,
|
321 |
+
"step": 39
|
322 |
+
},
|
323 |
+
{
|
324 |
+
"epoch": 40.0,
|
325 |
+
"eval_loss": 0.010605846531689167,
|
326 |
+
"eval_runtime": 3.3063,
|
327 |
+
"eval_samples_per_second": 4.537,
|
328 |
+
"eval_steps_per_second": 0.302,
|
329 |
+
"step": 40
|
330 |
+
},
|
331 |
+
{
|
332 |
+
"epoch": 41.0,
|
333 |
+
"eval_loss": 0.00994051992893219,
|
334 |
+
"eval_runtime": 3.0948,
|
335 |
+
"eval_samples_per_second": 4.847,
|
336 |
+
"eval_steps_per_second": 0.323,
|
337 |
+
"step": 41
|
338 |
+
},
|
339 |
+
{
|
340 |
+
"epoch": 42.0,
|
341 |
+
"eval_loss": 0.009244030341506004,
|
342 |
+
"eval_runtime": 3.2056,
|
343 |
+
"eval_samples_per_second": 4.679,
|
344 |
+
"eval_steps_per_second": 0.312,
|
345 |
+
"step": 42
|
346 |
+
},
|
347 |
+
{
|
348 |
+
"epoch": 43.0,
|
349 |
+
"eval_loss": 0.008599556051194668,
|
350 |
+
"eval_runtime": 3.1486,
|
351 |
+
"eval_samples_per_second": 4.764,
|
352 |
+
"eval_steps_per_second": 0.318,
|
353 |
+
"step": 43
|
354 |
+
},
|
355 |
+
{
|
356 |
+
"epoch": 44.0,
|
357 |
+
"eval_loss": 0.007965602912008762,
|
358 |
+
"eval_runtime": 3.2206,
|
359 |
+
"eval_samples_per_second": 4.658,
|
360 |
+
"eval_steps_per_second": 0.311,
|
361 |
+
"step": 44
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"epoch": 45.0,
|
365 |
+
"eval_loss": 0.007378988899290562,
|
366 |
+
"eval_runtime": 3.3572,
|
367 |
+
"eval_samples_per_second": 4.468,
|
368 |
+
"eval_steps_per_second": 0.298,
|
369 |
+
"step": 45
|
370 |
+
},
|
371 |
+
{
|
372 |
+
"epoch": 46.0,
|
373 |
+
"eval_loss": 0.0068125114776194096,
|
374 |
+
"eval_runtime": 3.2586,
|
375 |
+
"eval_samples_per_second": 4.603,
|
376 |
+
"eval_steps_per_second": 0.307,
|
377 |
+
"step": 46
|
378 |
+
},
|
379 |
+
{
|
380 |
+
"epoch": 47.0,
|
381 |
+
"eval_loss": 0.006272478960454464,
|
382 |
+
"eval_runtime": 3.2573,
|
383 |
+
"eval_samples_per_second": 4.605,
|
384 |
+
"eval_steps_per_second": 0.307,
|
385 |
+
"step": 47
|
386 |
+
},
|
387 |
+
{
|
388 |
+
"epoch": 48.0,
|
389 |
+
"eval_loss": 0.005782509222626686,
|
390 |
+
"eval_runtime": 3.5423,
|
391 |
+
"eval_samples_per_second": 4.235,
|
392 |
+
"eval_steps_per_second": 0.282,
|
393 |
+
"step": 48
|
394 |
+
},
|
395 |
+
{
|
396 |
+
"epoch": 49.0,
|
397 |
+
"eval_loss": 0.005347794853150845,
|
398 |
+
"eval_runtime": 3.2104,
|
399 |
+
"eval_samples_per_second": 4.672,
|
400 |
+
"eval_steps_per_second": 0.311,
|
401 |
+
"step": 49
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"epoch": 50.0,
|
405 |
+
"eval_loss": 0.004948154091835022,
|
406 |
+
"eval_runtime": 3.2715,
|
407 |
+
"eval_samples_per_second": 4.585,
|
408 |
+
"eval_steps_per_second": 0.306,
|
409 |
+
"step": 50
|
410 |
+
},
|
411 |
+
{
|
412 |
+
"epoch": 51.0,
|
413 |
+
"eval_loss": 0.004586088005453348,
|
414 |
+
"eval_runtime": 3.2838,
|
415 |
+
"eval_samples_per_second": 4.568,
|
416 |
+
"eval_steps_per_second": 0.305,
|
417 |
+
"step": 51
|
418 |
+
},
|
419 |
+
{
|
420 |
+
"epoch": 52.0,
|
421 |
+
"eval_loss": 0.004253142047673464,
|
422 |
+
"eval_runtime": 3.0694,
|
423 |
+
"eval_samples_per_second": 4.887,
|
424 |
+
"eval_steps_per_second": 0.326,
|
425 |
+
"step": 52
|
426 |
+
},
|
427 |
+
{
|
428 |
+
"epoch": 53.0,
|
429 |
+
"eval_loss": 0.003955014981329441,
|
430 |
+
"eval_runtime": 3.3957,
|
431 |
+
"eval_samples_per_second": 4.417,
|
432 |
+
"eval_steps_per_second": 0.294,
|
433 |
+
"step": 53
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"epoch": 54.0,
|
437 |
+
"eval_loss": 0.0036877128295600414,
|
438 |
+
"eval_runtime": 3.0922,
|
439 |
+
"eval_samples_per_second": 4.851,
|
440 |
+
"eval_steps_per_second": 0.323,
|
441 |
+
"step": 54
|
442 |
+
},
|
443 |
+
{
|
444 |
+
"epoch": 55.0,
|
445 |
+
"eval_loss": 0.003441128646954894,
|
446 |
+
"eval_runtime": 3.0702,
|
447 |
+
"eval_samples_per_second": 4.886,
|
448 |
+
"eval_steps_per_second": 0.326,
|
449 |
+
"step": 55
|
450 |
+
},
|
451 |
+
{
|
452 |
+
"epoch": 56.0,
|
453 |
+
"eval_loss": 0.00322518078610301,
|
454 |
+
"eval_runtime": 3.2255,
|
455 |
+
"eval_samples_per_second": 4.65,
|
456 |
+
"eval_steps_per_second": 0.31,
|
457 |
+
"step": 56
|
458 |
+
},
|
459 |
+
{
|
460 |
+
"epoch": 57.0,
|
461 |
+
"eval_loss": 0.003025263315066695,
|
462 |
+
"eval_runtime": 3.2191,
|
463 |
+
"eval_samples_per_second": 4.66,
|
464 |
+
"eval_steps_per_second": 0.311,
|
465 |
+
"step": 57
|
466 |
+
},
|
467 |
+
{
|
468 |
+
"epoch": 58.0,
|
469 |
+
"eval_loss": 0.002828031312674284,
|
470 |
+
"eval_runtime": 3.2217,
|
471 |
+
"eval_samples_per_second": 4.656,
|
472 |
+
"eval_steps_per_second": 0.31,
|
473 |
+
"step": 58
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"epoch": 59.0,
|
477 |
+
"eval_loss": 0.002643935615196824,
|
478 |
+
"eval_runtime": 3.2818,
|
479 |
+
"eval_samples_per_second": 4.571,
|
480 |
+
"eval_steps_per_second": 0.305,
|
481 |
+
"step": 59
|
482 |
+
},
|
483 |
+
{
|
484 |
+
"epoch": 60.0,
|
485 |
+
"eval_loss": 0.0024718190543353558,
|
486 |
+
"eval_runtime": 3.1539,
|
487 |
+
"eval_samples_per_second": 4.756,
|
488 |
+
"eval_steps_per_second": 0.317,
|
489 |
+
"step": 60
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"epoch": 61.0,
|
493 |
+
"eval_loss": 0.0023161820136010647,
|
494 |
+
"eval_runtime": 3.0914,
|
495 |
+
"eval_samples_per_second": 4.852,
|
496 |
+
"eval_steps_per_second": 0.323,
|
497 |
+
"step": 61
|
498 |
+
},
|
499 |
+
{
|
500 |
+
"epoch": 62.0,
|
501 |
+
"eval_loss": 0.0021715254988521338,
|
502 |
+
"eval_runtime": 3.9473,
|
503 |
+
"eval_samples_per_second": 3.8,
|
504 |
+
"eval_steps_per_second": 0.253,
|
505 |
+
"step": 62
|
506 |
+
},
|
507 |
+
{
|
508 |
+
"epoch": 63.0,
|
509 |
+
"eval_loss": 0.002039993414655328,
|
510 |
+
"eval_runtime": 3.0608,
|
511 |
+
"eval_samples_per_second": 4.901,
|
512 |
+
"eval_steps_per_second": 0.327,
|
513 |
+
"step": 63
|
514 |
+
},
|
515 |
+
{
|
516 |
+
"epoch": 64.0,
|
517 |
+
"eval_loss": 0.0019214763306081295,
|
518 |
+
"eval_runtime": 3.5577,
|
519 |
+
"eval_samples_per_second": 4.216,
|
520 |
+
"eval_steps_per_second": 0.281,
|
521 |
+
"step": 64
|
522 |
+
},
|
523 |
+
{
|
524 |
+
"epoch": 65.0,
|
525 |
+
"eval_loss": 0.0018131383694708347,
|
526 |
+
"eval_runtime": 3.6577,
|
527 |
+
"eval_samples_per_second": 4.101,
|
528 |
+
"eval_steps_per_second": 0.273,
|
529 |
+
"step": 65
|
530 |
+
},
|
531 |
+
{
|
532 |
+
"epoch": 66.0,
|
533 |
+
"eval_loss": 0.001715756836347282,
|
534 |
+
"eval_runtime": 3.2888,
|
535 |
+
"eval_samples_per_second": 4.561,
|
536 |
+
"eval_steps_per_second": 0.304,
|
537 |
+
"step": 66
|
538 |
+
},
|
539 |
+
{
|
540 |
+
"epoch": 67.0,
|
541 |
+
"eval_loss": 0.0016270950436592102,
|
542 |
+
"eval_runtime": 3.1551,
|
543 |
+
"eval_samples_per_second": 4.754,
|
544 |
+
"eval_steps_per_second": 0.317,
|
545 |
+
"step": 67
|
546 |
+
},
|
547 |
+
{
|
548 |
+
"epoch": 68.0,
|
549 |
+
"eval_loss": 0.0015468894271180034,
|
550 |
+
"eval_runtime": 3.1791,
|
551 |
+
"eval_samples_per_second": 4.718,
|
552 |
+
"eval_steps_per_second": 0.315,
|
553 |
+
"step": 68
|
554 |
+
},
|
555 |
+
{
|
556 |
+
"epoch": 69.0,
|
557 |
+
"eval_loss": 0.001476285862736404,
|
558 |
+
"eval_runtime": 3.1149,
|
559 |
+
"eval_samples_per_second": 4.816,
|
560 |
+
"eval_steps_per_second": 0.321,
|
561 |
+
"step": 69
|
562 |
+
},
|
563 |
+
{
|
564 |
+
"epoch": 70.0,
|
565 |
+
"eval_loss": 0.0014136920217424631,
|
566 |
+
"eval_runtime": 3.181,
|
567 |
+
"eval_samples_per_second": 4.716,
|
568 |
+
"eval_steps_per_second": 0.314,
|
569 |
+
"step": 70
|
570 |
+
},
|
571 |
+
{
|
572 |
+
"epoch": 71.0,
|
573 |
+
"eval_loss": 0.0013573016040027142,
|
574 |
+
"eval_runtime": 3.9126,
|
575 |
+
"eval_samples_per_second": 3.834,
|
576 |
+
"eval_steps_per_second": 0.256,
|
577 |
+
"step": 71
|
578 |
+
},
|
579 |
+
{
|
580 |
+
"epoch": 72.0,
|
581 |
+
"eval_loss": 0.0013067907420918345,
|
582 |
+
"eval_runtime": 3.1187,
|
583 |
+
"eval_samples_per_second": 4.81,
|
584 |
+
"eval_steps_per_second": 0.321,
|
585 |
+
"step": 72
|
586 |
+
},
|
587 |
+
{
|
588 |
+
"epoch": 73.0,
|
589 |
+
"eval_loss": 0.0012619098415598273,
|
590 |
+
"eval_runtime": 3.3641,
|
591 |
+
"eval_samples_per_second": 4.459,
|
592 |
+
"eval_steps_per_second": 0.297,
|
593 |
+
"step": 73
|
594 |
+
},
|
595 |
+
{
|
596 |
+
"epoch": 74.0,
|
597 |
+
"eval_loss": 0.0012214797316119075,
|
598 |
+
"eval_runtime": 3.3495,
|
599 |
+
"eval_samples_per_second": 4.478,
|
600 |
+
"eval_steps_per_second": 0.299,
|
601 |
+
"step": 74
|
602 |
+
},
|
603 |
+
{
|
604 |
+
"epoch": 75.0,
|
605 |
+
"eval_loss": 0.0011843887623399496,
|
606 |
+
"eval_runtime": 3.1759,
|
607 |
+
"eval_samples_per_second": 4.723,
|
608 |
+
"eval_steps_per_second": 0.315,
|
609 |
+
"step": 75
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"epoch": 76.0,
|
613 |
+
"eval_loss": 0.0011517057428136468,
|
614 |
+
"eval_runtime": 3.8153,
|
615 |
+
"eval_samples_per_second": 3.931,
|
616 |
+
"eval_steps_per_second": 0.262,
|
617 |
+
"step": 76
|
618 |
+
},
|
619 |
+
{
|
620 |
+
"epoch": 77.0,
|
621 |
+
"eval_loss": 0.0011234048288315535,
|
622 |
+
"eval_runtime": 3.3084,
|
623 |
+
"eval_samples_per_second": 4.534,
|
624 |
+
"eval_steps_per_second": 0.302,
|
625 |
+
"step": 77
|
626 |
+
},
|
627 |
+
{
|
628 |
+
"epoch": 78.0,
|
629 |
+
"eval_loss": 0.0010975470067933202,
|
630 |
+
"eval_runtime": 3.1214,
|
631 |
+
"eval_samples_per_second": 4.806,
|
632 |
+
"eval_steps_per_second": 0.32,
|
633 |
+
"step": 78
|
634 |
+
},
|
635 |
+
{
|
636 |
+
"epoch": 79.0,
|
637 |
+
"eval_loss": 0.0010739320423454046,
|
638 |
+
"eval_runtime": 3.145,
|
639 |
+
"eval_samples_per_second": 4.769,
|
640 |
+
"eval_steps_per_second": 0.318,
|
641 |
+
"step": 79
|
642 |
+
},
|
643 |
+
{
|
644 |
+
"epoch": 80.0,
|
645 |
+
"eval_loss": 0.0010527895065024495,
|
646 |
+
"eval_runtime": 3.2234,
|
647 |
+
"eval_samples_per_second": 4.654,
|
648 |
+
"eval_steps_per_second": 0.31,
|
649 |
+
"step": 80
|
650 |
+
},
|
651 |
+
{
|
652 |
+
"epoch": 81.0,
|
653 |
+
"eval_loss": 0.0010336448904126883,
|
654 |
+
"eval_runtime": 3.6565,
|
655 |
+
"eval_samples_per_second": 4.102,
|
656 |
+
"eval_steps_per_second": 0.273,
|
657 |
+
"step": 81
|
658 |
+
},
|
659 |
+
{
|
660 |
+
"epoch": 82.0,
|
661 |
+
"eval_loss": 0.001016051392070949,
|
662 |
+
"eval_runtime": 3.2992,
|
663 |
+
"eval_samples_per_second": 4.547,
|
664 |
+
"eval_steps_per_second": 0.303,
|
665 |
+
"step": 82
|
666 |
+
},
|
667 |
+
{
|
668 |
+
"epoch": 83.0,
|
669 |
+
"eval_loss": 0.0010000885231420398,
|
670 |
+
"eval_runtime": 3.2485,
|
671 |
+
"eval_samples_per_second": 4.618,
|
672 |
+
"eval_steps_per_second": 0.308,
|
673 |
+
"step": 83
|
674 |
+
},
|
675 |
+
{
|
676 |
+
"epoch": 84.0,
|
677 |
+
"eval_loss": 0.000985819729976356,
|
678 |
+
"eval_runtime": 3.4147,
|
679 |
+
"eval_samples_per_second": 4.393,
|
680 |
+
"eval_steps_per_second": 0.293,
|
681 |
+
"step": 84
|
682 |
+
},
|
683 |
+
{
|
684 |
+
"epoch": 85.0,
|
685 |
+
"eval_loss": 0.0009730439051054418,
|
686 |
+
"eval_runtime": 3.3375,
|
687 |
+
"eval_samples_per_second": 4.494,
|
688 |
+
"eval_steps_per_second": 0.3,
|
689 |
+
"step": 85
|
690 |
+
},
|
691 |
+
{
|
692 |
+
"epoch": 86.0,
|
693 |
+
"eval_loss": 0.0009613920701667666,
|
694 |
+
"eval_runtime": 3.2532,
|
695 |
+
"eval_samples_per_second": 4.611,
|
696 |
+
"eval_steps_per_second": 0.307,
|
697 |
+
"step": 86
|
698 |
+
},
|
699 |
+
{
|
700 |
+
"epoch": 87.0,
|
701 |
+
"eval_loss": 0.0009508637012913823,
|
702 |
+
"eval_runtime": 3.3259,
|
703 |
+
"eval_samples_per_second": 4.51,
|
704 |
+
"eval_steps_per_second": 0.301,
|
705 |
+
"step": 87
|
706 |
+
},
|
707 |
+
{
|
708 |
+
"epoch": 88.0,
|
709 |
+
"eval_loss": 0.0009414219530299306,
|
710 |
+
"eval_runtime": 3.1885,
|
711 |
+
"eval_samples_per_second": 4.704,
|
712 |
+
"eval_steps_per_second": 0.314,
|
713 |
+
"step": 88
|
714 |
+
},
|
715 |
+
{
|
716 |
+
"epoch": 89.0,
|
717 |
+
"eval_loss": 0.0009328797459602356,
|
718 |
+
"eval_runtime": 3.9468,
|
719 |
+
"eval_samples_per_second": 3.801,
|
720 |
+
"eval_steps_per_second": 0.253,
|
721 |
+
"step": 89
|
722 |
+
},
|
723 |
+
{
|
724 |
+
"epoch": 90.0,
|
725 |
+
"eval_loss": 0.0009253285243175924,
|
726 |
+
"eval_runtime": 3.344,
|
727 |
+
"eval_samples_per_second": 4.486,
|
728 |
+
"eval_steps_per_second": 0.299,
|
729 |
+
"step": 90
|
730 |
+
},
|
731 |
+
{
|
732 |
+
"epoch": 91.0,
|
733 |
+
"eval_loss": 0.0009186835959553719,
|
734 |
+
"eval_runtime": 3.403,
|
735 |
+
"eval_samples_per_second": 4.408,
|
736 |
+
"eval_steps_per_second": 0.294,
|
737 |
+
"step": 91
|
738 |
+
},
|
739 |
+
{
|
740 |
+
"epoch": 92.0,
|
741 |
+
"eval_loss": 0.0009127946686930954,
|
742 |
+
"eval_runtime": 3.2406,
|
743 |
+
"eval_samples_per_second": 4.629,
|
744 |
+
"eval_steps_per_second": 0.309,
|
745 |
+
"step": 92
|
746 |
+
},
|
747 |
+
{
|
748 |
+
"epoch": 93.0,
|
749 |
+
"eval_loss": 0.0009078615694306791,
|
750 |
+
"eval_runtime": 3.2285,
|
751 |
+
"eval_samples_per_second": 4.646,
|
752 |
+
"eval_steps_per_second": 0.31,
|
753 |
+
"step": 93
|
754 |
+
},
|
755 |
+
{
|
756 |
+
"epoch": 94.0,
|
757 |
+
"eval_loss": 0.0009037015843205154,
|
758 |
+
"eval_runtime": 2.9444,
|
759 |
+
"eval_samples_per_second": 5.094,
|
760 |
+
"eval_steps_per_second": 0.34,
|
761 |
+
"step": 94
|
762 |
+
},
|
763 |
+
{
|
764 |
+
"epoch": 95.0,
|
765 |
+
"eval_loss": 0.0009001877042464912,
|
766 |
+
"eval_runtime": 3.459,
|
767 |
+
"eval_samples_per_second": 4.337,
|
768 |
+
"eval_steps_per_second": 0.289,
|
769 |
+
"step": 95
|
770 |
+
},
|
771 |
+
{
|
772 |
+
"epoch": 96.0,
|
773 |
+
"eval_loss": 0.0008972398354671896,
|
774 |
+
"eval_runtime": 3.2953,
|
775 |
+
"eval_samples_per_second": 4.552,
|
776 |
+
"eval_steps_per_second": 0.303,
|
777 |
+
"step": 96
|
778 |
+
},
|
779 |
+
{
|
780 |
+
"epoch": 97.0,
|
781 |
+
"eval_loss": 0.0008948465110734105,
|
782 |
+
"eval_runtime": 3.238,
|
783 |
+
"eval_samples_per_second": 4.633,
|
784 |
+
"eval_steps_per_second": 0.309,
|
785 |
+
"step": 97
|
786 |
+
},
|
787 |
+
{
|
788 |
+
"epoch": 98.0,
|
789 |
+
"eval_loss": 0.0008930906769819558,
|
790 |
+
"eval_runtime": 3.5433,
|
791 |
+
"eval_samples_per_second": 4.233,
|
792 |
+
"eval_steps_per_second": 0.282,
|
793 |
+
"step": 98
|
794 |
+
},
|
795 |
+
{
|
796 |
+
"epoch": 99.0,
|
797 |
+
"eval_loss": 0.0008919287356548011,
|
798 |
+
"eval_runtime": 3.3619,
|
799 |
+
"eval_samples_per_second": 4.462,
|
800 |
+
"eval_steps_per_second": 0.297,
|
801 |
+
"step": 99
|
802 |
+
}
|
803 |
+
],
|
804 |
+
"logging_steps": 500,
|
805 |
+
"max_steps": 100,
|
806 |
+
"num_input_tokens_seen": 0,
|
807 |
+
"num_train_epochs": 100,
|
808 |
+
"save_steps": 500,
|
809 |
+
"stateful_callbacks": {
|
810 |
+
"TrainerControl": {
|
811 |
+
"args": {
|
812 |
+
"should_epoch_stop": false,
|
813 |
+
"should_evaluate": false,
|
814 |
+
"should_log": false,
|
815 |
+
"should_save": true,
|
816 |
+
"should_training_stop": true
|
817 |
+
},
|
818 |
+
"attributes": {}
|
819 |
+
}
|
820 |
+
},
|
821 |
+
"total_flos": 97994256768000.0,
|
822 |
+
"train_batch_size": 16,
|
823 |
+
"trial_name": null,
|
824 |
+
"trial_params": null
|
825 |
+
}
|
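For reference, a trainer state like the one above (100 epochs, batch size 16, an evaluation pass every epoch, and a final save at step 100) corresponds to a Hugging Face Trainer configured roughly as follows. This is a hedged sketch, not the verbatim contents of fine_tune_nlu.py; model, train_dataset, and eval_dataset are placeholders.

from transformers import TrainingArguments, Trainer

# Sketch of a configuration consistent with the trainer_state.json above.
# The real fine_tune_nlu.py settings may differ; the model and datasets
# below are placeholders.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=100,            # matches "num_train_epochs": 100
    per_device_train_batch_size=16,  # matches "train_batch_size": 16
    evaluation_strategy="epoch",     # yields the per-epoch eval_loss entries
    logging_steps=500,               # matches "logging_steps": 500
    save_steps=500,                  # matches "save_steps": 500
)

trainer = Trainer(
    model=model,                  # e.g. a BertForTokenClassification instance
    args=training_args,
    train_dataset=train_dataset,  # placeholder: the tokenized NLU dataset
    eval_dataset=eval_dataset,    # placeholder
)
trainer.train()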
results/checkpoint-100/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4165414f4aaed45dd5f68a45321417b874e253512feaa0d080315683f181b0aa
size 5240
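The three lines above are a Git LFS pointer rather than the binary itself: the oid field is the SHA-256 hash of the stored file. After pulling the real file, the hash can be checked against the pointer; a minimal sketch, assuming the file has been fetched to the path shown:

import hashlib

# Compute the sha256 of a pulled LFS file and compare it to the pointer's oid.
def lfs_oid(path: str) -> str:
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()

expected = "4165414f4aaed45dd5f68a45321417b874e253512feaa0d080315683f181b0aa"
assert lfs_oid("results/checkpoint-100/training_args.bin") == expected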
test_NLU.py
ADDED
@@ -0,0 +1,87 @@
import torch
from transformers import BertTokenizerFast, BertForTokenClassification

# The tokenizer comes from the multilingual BERT base model; the slot-filling
# weights are loaded from the fine-tuned checkpoint produced by fine_tune_nlu.py.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
model = BertForTokenClassification.from_pretrained("./results/checkpoint-100")

# Maps the model's class ids back to BIO slot labels.
slot_label_map = {
    0: "O", 1: "B-project_id", 2: "I-project_id", 3: "B-reason", 4: "I-reason",
    5: "B-amount", 6: "I-amount", 7: "B-project_name", 8: "I-project_name",
    9: "B-status", 10: "I-status", 11: "B-riyals", 12: "I-riyals"
}


def decode_slots(tokens, predictions, slot_label_map):
    """Collapse per-token BIO predictions into a {slot_name: text} dict."""
    slots = {}
    current_slot = None
    current_value = []

    for token, pred_id in zip(tokens, predictions):
        label = slot_label_map[pred_id]

        # Handle B- and I- slots
        if label.startswith("B-"):  # beginning of a new slot
            if current_slot:
                slots[current_slot] = tokenizer.convert_tokens_to_string(current_value)
            current_slot = label[2:]  # extract the slot name
            current_value = [token]  # start a new slot
        elif label.startswith("I-") and current_slot == label[2:]:  # continuation of the current slot
            current_value.append(token)
        else:  # "O" or a stray I- tag: close any open slot
            if current_slot:
                slots[current_slot] = tokenizer.convert_tokens_to_string(current_value)
            current_slot = None
            current_value = []

    if current_slot:  # flush a slot that runs to the end of the sequence
        slots[current_slot] = tokenizer.convert_tokens_to_string(current_value)

    return slots


def predict_intent_and_slots(text, model, tokenizer, slot_label_map):
    encoding = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,  # same as during training
        return_tensors="pt"
    )
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        predictions = torch.argmax(logits, dim=2).squeeze().tolist()

    # Padding and special tokens are expected to be predicted as "O",
    # so decode_slots simply skips them.
    tokens = tokenizer.convert_ids_to_tokens(input_ids.squeeze().tolist())
    predictions = predictions[:len(tokens)]

    slots = decode_slots(tokens, predictions, slot_label_map)

    # Intent classification is handled separately (intent_recognition.py);
    # this module only performs slot filling.
    return {"utterance": text, "slots": slots}


def get_slots(text):
    result = predict_intent_and_slots(text, model, tokenizer, slot_label_map)
    return result["slots"]


# Run the smoke test only when executed directly, so that importing
# get_slots from app.py does not trigger an extra inference pass.
if __name__ == "__main__":
    test_text = "Hey, I need to request money for a project name Abha University and id is 123 and the amount is 500 riyals"
    result = predict_intent_and_slots(test_text, model, tokenizer, slot_label_map)

    print("Prediction Result:")
    print(result)
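decode_slots above is standard BIO decoding: a B- label opens a span, a matching I- label extends it, and anything else closes it. The behavior is easiest to see on a hand-labeled toy sequence, with a plain string join standing in for the tokenizer's detokenization (illustrative only; the real code calls convert_tokens_to_string):

# Toy BIO decoding run, independent of the model and tokenizer.
tokens = ["request", "money", "for", "Abha", "University", "id", "123"]
labels = ["O", "O", "O", "B-project_name", "I-project_name", "O", "B-project_id"]

slots, current, value = {}, None, []
for tok, lab in zip(tokens, labels):
    if lab.startswith("B-"):  # open a new span
        if current:
            slots[current] = " ".join(value)
        current, value = lab[2:], [tok]
    elif lab.startswith("I-") and current == lab[2:]:  # extend the open span
        value.append(tok)
    else:  # "O" closes any open span
        if current:
            slots[current] = " ".join(value)
        current, value = None, []
if current:  # flush a span that reaches the end
    slots[current] = " ".join(value)

print(slots)  # {'project_name': 'Abha University', 'project_id': '123'}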
whisper_stt.py
ADDED
@@ -0,0 +1,37 @@
import tempfile

from transformers import pipeline


def transcribe_audio_raw(file_path: str) -> str:
    """Transcribe an audio file that already exists on disk."""
    whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device="cpu")
    transcription = whisper_pipe(file_path)
    print(transcription)
    return transcription["text"]


def transcribe_audio(uploaded_file) -> str:
    """Transcribe a Streamlit upload by spooling it to a temporary .wav file."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file.write(uploaded_file.read())
        file_path = temp_file.name

    whisper_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device="cpu")
    transcription = whisper_pipe(file_path)
    return transcription["text"]
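One performance note: transcribe_audio constructs a fresh Whisper pipeline on every call, so the model is reloaded for each uploaded clip. Caching the pipeline at module level (or behind Streamlit's st.cache_resource) would amortize that cost. A minimal sketch of such a variant; transcribe_audio_cached is a hypothetical name, not part of the repo:

import tempfile
from functools import lru_cache

from transformers import pipeline


@lru_cache(maxsize=1)
def _get_whisper_pipe():
    # Load the tiny English Whisper model once and reuse it across calls.
    return pipeline("automatic-speech-recognition", model="openai/whisper-tiny.en", device="cpu")


def transcribe_audio_cached(uploaded_file) -> str:
    # Hypothetical variant of transcribe_audio that reuses the cached pipeline.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file.write(uploaded_file.read())
        file_path = temp_file.name
    return _get_whisper_pipe()(file_path)["text"]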