Falonne Kpamegan commited on
Commit
732a8f8
·
1 Parent(s): 75bcde3
Files changed (3) hide show
  1. requirements.txt +3 -0
  2. src/model/test.py +51 -0
  3. src/model/train.py +62 -0
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ librosa
src/model/test.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+ from datasets import load_dataset
3
+ from evaluate import load as load_metric
4
+ from transformers import (
5
+ Wav2Vec2ForCTC,
6
+ Wav2Vec2Processor,
7
+ )
8
+ import torch
9
+ import re
10
+ import sys
11
+
12
+ model_name = "facebook/wav2vec2-large-xlsr-53-french"
13
+ device = "cpu"
14
+
15
+ chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"]'
16
+
17
+ model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
18
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
19
+
20
+ ds = load_dataset("facebook/voxpopuli", "fr", trust_remote_code=True)
21
+
22
+ resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)
23
+
24
+ def map_to_array(batch):
25
+ speech, _ = torchaudio.load(batch["path"])
26
+ batch["speech"] = resampler.forward(speech.squeeze(0)).numpy()
27
+ batch["sampling_rate"] = resampler.new_freq
28
+ batch["sentence"] = re.sub(chars_to_ignore_regex, '', batch["sentence"]).lower().replace("’", "'")
29
+ return batch
30
+
31
+ ds = ds.map(map_to_array)
32
+
33
+ def map_to_pred(batch):
34
+ features = processor(batch["speech"], sampling_rate=batch["sampling_rate"][0], padding=True, return_tensors="pt")
35
+ input_values = features.input_values.to(device)
36
+ attention_mask = features.attention_mask.to(device)
37
+ with torch.no_grad():
38
+ logits = model(input_values, attention_mask=attention_mask).logits
39
+ pred_ids = torch.argmax(logits, dim=-1)
40
+ batch["predicted"] = processor.batch_decode(pred_ids)
41
+ batch["target"] = batch["sentence"]
42
+ return batch
43
+
44
+ result = ds.map(map_to_pred, batched=True, batch_size=16, remove_columns=list(ds.features.keys()))
45
+
46
+ wer = load_metric("wer")
47
+ wer_score = wer.compute(predictions=result["predicted"], references=result["target"])
48
+ print(f"WER: {wer_score}")
49
+
50
+
51
+ # print(wer.compute(predictions=result["predicted"], references=result["target"]))
src/model/train.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
2
+ import torch
3
+ import librosa
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+
7
+ # Charger le modèle et le processeur Wav2Vec 2.0
8
+ model_name = "facebook/wav2vec2-large-960h"
9
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
10
+ model = Wav2Vec2ForCTC.from_pretrained(model_name)
11
+
12
+ # Charger l'audio
13
+ audio_file = "path_to_audio_file.wav"
14
+ y, sr = librosa.load(audio_file, sr=16000) # Assurez-vous que le sample rate est 16kHz
15
+
16
+ # Prétraiter l'audio avec le processeur Wav2Vec 2.0
17
+ input_values = processor(y, return_tensors="pt").input_values
18
+
19
+ # Obtenir la prédiction (logits)
20
+ with torch.no_grad():
21
+ logits = model(input_values).logits
22
+
23
+ # Obtenir les IDs des tokens prédits (transcription)
24
+ predicted_ids = torch.argmax(logits, dim=-1)
25
+
26
+ # Décoder les IDs pour obtenir le texte transcrit
27
+ transcription = processor.decode(predicted_ids[0])
28
+
29
+ print("Transcription:", transcription)
30
+
31
+
32
+ # Extraire le pitch (hauteur tonale) et l'intensité
33
+ pitch, magnitudes = librosa.core.piptrack(y=y, sr=sr)
34
+ intensity = librosa.feature.rms(y=y) # Intensité (volume)
35
+
36
+ # Calculer le tempo (vitesse de parole)
37
+ tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
38
+
39
+ # Affichage du pitch
40
+ plt.figure(figsize=(10, 6))
41
+ librosa.display.specshow(pitch, x_axis='time', y_axis='log')
42
+ plt.colorbar()
43
+ plt.title("Pitch (Hauteur Tonale)")
44
+ plt.show()
45
+
46
+ # Affichage de l'intensité
47
+ plt.figure(figsize=(10, 6))
48
+ librosa.display.specshow(intensity, x_axis='time')
49
+ plt.colorbar()
50
+ plt.title("Intensité")
51
+ plt.show()
52
+
53
+ # Fusionner la transcription avec les caractéristiques prosodiques (pitch, intensité, tempo)
54
+ features = np.hstack([
55
+ np.mean(intensity, axis=1), # Moyenne de l'intensité
56
+ np.mean(pitch, axis=1), # Moyenne du pitch
57
+ tempo # Tempo
58
+ ])
59
+
60
+ # Afficher les caractéristiques extraites
61
+ print("Caractéristiques combinées :")
62
+ print(features)