Commit 732a8f8
Falonne Kpamegan committed
Parent(s): 75bcde3
base code
Files changed:
- requirements.txt (+3, -0)
- src/model/test.py (+51, -0)
- src/model/train.py (+62, -0)
requirements.txt
ADDED
@@ -0,0 +1,3 @@
transformers
torch
librosa
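Note that the two scripts in this commit also import datasets, evaluate, numpy, and matplotlib, which are not listed above; the scripts will only run if those are installed as well (e.g. pip install datasets evaluate numpy matplotlib).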
src/model/test.py
ADDED
@@ -0,0 +1,51 @@
import re

import torch
from datasets import load_dataset
from evaluate import load as load_metric
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

model_name = "facebook/wav2vec2-large-xlsr-53-french"
device = "cpu"

# Raw string avoids invalid-escape-sequence warnings on recent Python versions.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"]'

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(model_name)

# A concrete split is needed here: without `split=`, load_dataset returns a
# DatasetDict and the `ds.features` access below fails.
ds = load_dataset("facebook/voxpopuli", "fr", split="test", trust_remote_code=True)

def map_to_array(batch):
    # VoxPopuli exposes audio decoded at 16 kHz under the "audio" column and
    # the transcript under "normalized_text" (there are no "path"/"sentence"
    # columns), so no 48 kHz -> 16 kHz resampling step is needed.
    batch["speech"] = batch["audio"]["array"]
    batch["sampling_rate"] = batch["audio"]["sampling_rate"]
    batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["normalized_text"]).lower().replace("’", "'")
    return batch

ds = ds.map(map_to_array)

def map_to_pred(batch):
    features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["predicted"] = processor.batch_decode(pred_ids)
    batch["target"] = batch["sentence"]
    return batch

result = ds.map(map_to_pred, batched=True, batch_size=16, remove_columns=list(ds.features.keys()))

wer = load_metric("wer")
wer_score = wer.compute(predictions=result["predicted"], references=result["target"])
print(f"WER: {wer_score}")
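For context, word error rate is (substitutions + deletions + insertions) divided by the number of reference words. A minimal sanity check of the same `evaluate` metric, using made-up strings rather than model output:

from evaluate import load as load_metric

wer = load_metric("wer")
# "bonjour tout le monde" vs. "bonjour le monde": one insertion against
# three reference words, so the expected score is 1/3.
print(wer.compute(predictions=["bonjour tout le monde"], references=["bonjour le monde"]))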
src/model/train.py
ADDED
@@ -0,0 +1,62 @@
import librosa
import librosa.display  # specshow lives in a submodule that must be imported explicitly
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

# Load the Wav2Vec 2.0 model and processor
model_name = "facebook/wav2vec2-large-960h"
processor = Wav2Vec2Processor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# Load the audio
audio_file = "path_to_audio_file.wav"
y, sr = librosa.load(audio_file, sr=16000)  # make sure the sampling rate is 16 kHz

# Preprocess the audio with the Wav2Vec 2.0 processor
# (passing sampling_rate silences the "sampling rate not provided" warning)
input_values = processor(y, sampling_rate=sr, return_tensors="pt").input_values

# Run inference to get the logits
with torch.no_grad():
    logits = model(input_values).logits

# Get the predicted token IDs
predicted_ids = torch.argmax(logits, dim=-1)

# Decode the IDs into the transcribed text
transcription = processor.decode(predicted_ids[0])

print("Transcription:", transcription)

# Extract pitch and intensity
pitch, magnitudes = librosa.piptrack(y=y, sr=sr)
intensity = librosa.feature.rms(y=y)  # intensity (volume), shape (1, n_frames)

# Estimate the tempo (a rough proxy for speaking rate)
tempo, _ = librosa.beat.beat_track(y=y, sr=sr)

# Plot the pitch
plt.figure(figsize=(10, 6))
librosa.display.specshow(pitch, x_axis='time', y_axis='log')
plt.colorbar()
plt.title("Pitch")
plt.show()

# Plot the intensity
plt.figure(figsize=(10, 6))
librosa.display.specshow(intensity, x_axis='time')
plt.colorbar()
plt.title("Intensity")
plt.show()

# Combine the prosodic features (pitch, intensity, tempo) into one vector;
# the transcription itself stays in `transcription`
features = np.hstack([
    np.mean(intensity, axis=1),   # mean intensity (1 value)
    np.mean(pitch, axis=1),       # mean pitch per frequency bin
    np.atleast_1d(tempo),         # tempo (1 value)
])

# Display the extracted features
print("Combined features:")
print(features)
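A quick shape check on the combined vector (a sketch, assuming librosa's defaults, where piptrack uses n_fft=2048 and therefore 1025 frequency bins):

print(intensity.shape, pitch.shape)  # (1, n_frames) and (1025, n_frames)
print(features.shape)                # (1027,) = 1 RMS mean + 1025 pitch-bin means + 1 tempo value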