File size: 5,304 Bytes
736e98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7fd66f8
736e98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa4f9dd
 
 
 
 
 
a9242dd
aa4f9dd
 
 
 
 
 
 
 
 
 
 
 
 
a9242dd
aa4f9dd
 
 
 
 
 
a9242dd
 
736e98d
 
 
 
 
 
a9242dd
736e98d
 
 
 
a9242dd
736e98d
 
a9242dd
736e98d
 
c8fe953
736e98d
a9242dd
 
aa4f9dd
 
 
 
a9242dd
aa4f9dd
0e0eeeb
 
abec1f7
 
736e98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from fastapi import FastAPI, UploadFile, File, Body, Form
from pathlib import Path
from fastapi.middleware.cors import CORSMiddleware
from typing import List
import numpy as np
from resemblyzer import preprocess_wav, VoiceEncoder
from itertools import groupby
from pathlib import Path
from tqdm import tqdm
import os
from sklearn.metrics.pairwise import cosine_similarity, cosine_distances
import glob


# Folder (relative to the working directory) where uploaded recordings are
# written before preprocessing; created on import if it does not exist.
UPLOAD_DIR = Path("uploads")
UPLOAD_DIR.mkdir(exist_ok=True, parents=True)

# Point numba's on-disk cache at /tmp.
# NOTE(review): resemblyzer is imported above, before this line runs; if that
# import already loaded numba, this setting may come too late to matter — confirm
# and consider moving it above the imports.
os.environ.update(NUMBA_CACHE_DIR='/tmp/')

app = FastAPI()

# Accept cross-origin requests from any origin, with any method and any
# header, so a separately served frontend can call this API.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Delete all files in the uploads folder.


def delFiles(directory="uploads"):
    """Delete the files directly inside *directory* (the uploads folder by default).

    Only entries matched by the ``*`` glob are considered, so dot-files are
    left in place. Sub-directories are skipped because ``os.remove`` cannot
    delete them. An entry that disappears before it is removed (for example,
    one deleted by a concurrent request) is ignored. A missing *directory* is
    a no-op.

    Args:
        directory: folder to empty. The default is the same relative
            ``uploads`` path that ``UPLOAD_DIR`` points at.
    """
    # Escape the folder name so characters like "[" are not read as glob syntax.
    pattern = os.path.join(glob.escape(os.fspath(directory)), "*")
    for path in glob.glob(pattern):
        if os.path.isdir(path) and not os.path.islink(path):
            continue  # os.remove cannot delete directories
        try:
            os.remove(path)
        except FileNotFoundError:
            pass  # already gone; nothing left to delete

# Main function: returns the name of the enrolled person whose voice is most similar to the test audio.


def _safe_filename(raw):
    """Return the final path component of *raw* (or "upload" if empty) so client text cannot escape UPLOAD_DIR."""
    return os.path.basename(raw or "") or "upload"


async def _save_upload(file_upload, filename):
    """Write the full contents of *file_upload* to ``UPLOAD_DIR / filename`` and return that path."""
    data = await file_upload.read()
    file_path = UPLOAD_DIR / filename
    with open(file_path, "wb") as file_object:
        file_object.write(data)
    return file_path


async def predictor(names, file_uploads, usersNum, recordingsNum):
    """Identify which enrolled speaker most likely produced the test recording.

    Each enrolled speaker's recordings are saved to UPLOAD_DIR, preprocessed
    with ``preprocess_wav``, and combined into one embedding with
    ``VoiceEncoder.embed_speaker``. The test recording is embedded the same
    way, and the speaker whose embedding has the highest cosine similarity to
    it is returned.

    Args:
        names: form values ``[name1, ..., nameN, "test"]``. The trailing
            entry labels the test upload and is ignored. The caller's list
            is not modified.
        file_uploads: ``recordingsNum`` uploads per enrolled name, in the
            same order as *names*, followed by the test recording as the
            last item.
        usersNum: unused; kept for interface compatibility.
        recordingsNum: number of enrollment recordings per name (int or
            numeric string).

    Returns:
        The enrolled name with the highest similarity to the test audio.

    Raises:
        ValueError: if *recordingsNum* < 1, too few files were uploaded, or
            no enrolled speaker's audio could be preprocessed.
        Exception: whatever ``preprocess_wav`` raises for the test recording.

    NOTE(review): the preprocessing and embedding calls run synchronously
    inside this coroutine, so they block the event loop while they run.
    NOTE(review): all requests share UPLOAD_DIR, and resultGenerator empties
    it with delFiles(), so concurrent requests may delete each other's files.
    Consider a per-request directory.
    """
    recordings_per_user = int(recordingsNum)
    if recordings_per_user < 1:
        raise ValueError("recordingsNum must be at least 1")
    # names ends with the "test" key. Slice it off rather than pop() so the
    # caller's list is left untouched.
    enrolled_names = names[:-1]
    needed = len(enrolled_names) * recordings_per_user + 1
    if len(file_uploads) < needed:
        raise ValueError(
            f"expected at least {needed} uploaded files, got {len(file_uploads)}")

    encoder = VoiceEncoder()

    # Enroll every speaker. Each embedding is stored under its own name, so a
    # speaker whose audio fails to preprocess is skipped without shifting the
    # name/embedding pairing of the speakers that follow.
    enrolled = {}
    for user_ind, name in enumerate(enrolled_names):
        first = user_ind * recordings_per_user
        wav_fpaths = []
        for file_ind in range(first, first + recordings_per_user):
            file_upload = file_uploads[file_ind]
            # The running index keeps paths unique even when a client sends the
            # same filename (e.g. "blob") for several recordings.
            filename = (f"{_safe_filename(name)}¬{file_ind}¬"
                        f"{_safe_filename(file_upload.filename)}")
            wav_fpaths.append(await _save_upload(file_upload, filename))
        try:
            wavs = [preprocess_wav(fpath) for fpath in
                    tqdm(wav_fpaths, "Preprocessing wavs", len(wav_fpaths), unit="wavs")]
        except Exception as error:
            print("An exception occurred:", type(error).__name__)
            print("Exception details:", error)
            continue  # best effort: skip this speaker, keep the others
        enrolled[name] = encoder.embed_speaker(wavs)

    if not enrolled:
        raise ValueError("no enrolled speaker's audio could be preprocessed")

    # Save, preprocess and embed the test recording (always the last upload).
    test_upload = file_uploads[-1]
    test_fpath = await _save_upload(
        test_upload,
        f"test¬{len(file_uploads) - 1}¬{_safe_filename(test_upload.filename)}")
    print("about to test")
    try:
        test_wav = preprocess_wav(test_fpath)
    except Exception as error:
        print("An exception occurred:", type(error).__name__)
        print("Exception details:", error)
        raise  # without a test embedding no prediction is possible
    test_emb = encoder.embed_speaker([test_wav]).reshape(1, -1)

    # Cosine similarity of every enrolled embedding (one row each) to the test embedding.
    sims = cosine_similarity(np.vstack(list(enrolled.values())), test_emb)[:, 0]
    speakers = {name: float(sim) for name, sim in zip(enrolled, sims)}

    # Rescale so the scores sum to 1 (each speaker's share of the total).
    # Only do this when the total is positive: a zero total would divide by
    # zero, and a negative one would invert the ranking.
    total = sum(speakers.values())
    if total > 0:
        speakers = {name: sim / total for name, sim in speakers.items()}

    identified = max(speakers, key=speakers.get)
    print("\nThe identity of the test speaker:\n", identified, "with a similarity with test of",
          speakers[identified]*100, "percent match as compared to all.")
    return identified


# Prediction endpoint: accepts the speakers' names, their recordings and the test recording as multipart form data.
@app.post("/predict/")
async def resultGenerator(names: List[str] = Form(...), file_uploads: List[UploadFile] = File(...), usersNum: str = Form(...), recordingsNum: str = Form(...)):
    """Identify which enrolled speaker produced the test recording.

    Expects multipart form data:
    - ``names``: ``[name1, ..., nameN, "test"]``.
    - ``file_uploads``: each name's recordings in order, followed by the test
      recording.
    - ``usersNum`` and ``recordingsNum``.

    Returns ``{"result": <name>}`` on success and ``{"error": <message>}``
    otherwise. Once a prediction has been attempted, the uploads folder is
    cleared whether or not it succeeded.
    """
    # names is [name1, name2, ..., "test"]. At least two enrolled speakers plus
    # the "test" entry are required, i.e. more than two entries.
    if len(names) <= 2:
        return {"error": "Incorrect data provided"}
    try:
        result = await predictor(names, file_uploads, usersNum, recordingsNum)
    except Exception as error:
        # Log the real cause instead of silently swallowing it.
        print("An exception occurred:", type(error).__name__)
        print("Exception details:", error)
        return {"error": "Server not responding"}
    finally:
        # Delete the uploaded files on failure as well as on success.
        try:
            delFiles()
        except OSError as cleanup_error:
            print("Could not clean the uploads folder:", cleanup_error)
    print('## Test Audio Belonged To: {}'.format(result))
    return {"result": result}