Qa5im commited on
Commit
736e98d
·
1 Parent(s): b5eb12e

added the server mod

Browse files
Files changed (1) hide show
  1. server.py +121 -0
server.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Body, Form
2
+ from pathlib import Path
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from typing import List
5
+ import numpy as np
6
+ from resemblyzer import preprocess_wav, VoiceEncoder
7
+ from itertools import groupby
8
+ from pathlib import Path
9
+ from tqdm import tqdm
10
+ import os
11
+ from sklearn.metrics.pairwise import cosine_similarity, cosine_distances
12
+ import glob
13
+
14
+
15
+ UPLOAD_DIR = Path() / "uploads"
16
+ UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
17
+
18
+ app = FastAPI()
19
+
20
+
21
+ # Add a CORS middleware to allow cross-origin requests from the frontend
22
+ app.add_middleware(
23
+ CORSMiddleware,
24
+ allow_origins=["*"],
25
+ allow_methods=["*"],
26
+ allow_headers=["*"],
27
+ )
28
+
29
+ # del all files in uploads folder
30
+
31
+
32
+ def delFiles():
33
+ files = glob.glob('uploads/*')
34
+ for f in files:
35
+ os.remove(f)
36
+
37
+ # main function which returns the name of person which has highest similarity index with test audio
38
+
39
+
40
+ async def predictor(names, file_uploads, usersNum, recordingsNum):
41
+ speaker_embed_list = []
42
+ encoder = VoiceEncoder()
43
+ # Iterating over list of files corresponding to each user
44
+ speaker_wavs_list = []
45
+ fileInd = 0
46
+ names.pop() # to remove key named "test"
47
+ for name in names:
48
+ wav_fpaths = []
49
+ for ind in range(int(recordingsNum)):
50
+ file_upload = file_uploads[fileInd]
51
+ data = await file_upload.read()
52
+ # appending person's name to the his/her recordings
53
+ filename = name+"¬"+file_upload.filename
54
+ file_path = UPLOAD_DIR / filename
55
+ with open(file_path, "wb") as file_object:
56
+ file_object.write(data)
57
+ wav_fpaths.append(Path(file_path))
58
+ fileInd += 1
59
+ try:
60
+ speaker_wavs = {speaker: list(map(preprocess_wav, wav_fpaths)) for speaker, wav_fpaths in
61
+ groupby(tqdm(wav_fpaths, "Preprocessing wavs", len(wav_fpaths), unit="wavs"),
62
+ lambda wav_fpath: os.path.basename(wav_fpath).split("¬")[0])} # extracting person's name from file name
63
+ speaker_wavs_list.append(speaker_wavs)
64
+ except Exception as e:
65
+ print("error ", e)
66
+
67
+ # make a list of the pre-processed audios ki arrays
68
+ for sp_wvs in speaker_wavs_list:
69
+ speaker_embed_list.append(
70
+ np.array([encoder.embed_speaker(wavs) for wavs in sp_wvs.values()]))
71
+
72
+
73
+ # making preprocessed test audio
74
+ wav_fpaths = []
75
+ file_upload = file_uploads[-1]
76
+ data = await file_upload.read()
77
+ filename = "test¬"+file_upload.filename
78
+ file_path = UPLOAD_DIR / filename
79
+ with open(file_path, "wb") as file_object:
80
+ file_object.write(data)
81
+ wav_fpaths.append(Path(file_path))
82
+ test_pos_wavs = {speaker: list(map(preprocess_wav, wav_fpaths)) for speaker, wav_fpaths in
83
+ groupby(tqdm(wav_fpaths, "Preprocessing wavs", len(wav_fpaths), unit="wavs"),
84
+ lambda wav_fpath: "test")}
85
+ test_pos_emb = np.array([encoder.embed_speaker(wavs)
86
+ for wavs in test_pos_wavs.values()])
87
+
88
+ # calculates cosine similarity between the ground truth (test file) and registered audios
89
+ speakers = {}
90
+ val = 0
91
+ for spkr_embd in speaker_embed_list:
92
+ key_val = names[val]
93
+ spkr_sim = cosine_similarity(spkr_embd, test_pos_emb)[0][0]
94
+ speakers[key_val] = spkr_sim
95
+ val += 1
96
+
97
+ norm = [float(i)/sum(speakers.values()) for i in speakers.values()]
98
+ for i in range(len(norm)):
99
+ key_val = names[i]
100
+ speakers[key_val] = norm[i]
101
+
102
+ identified = max(speakers, key=speakers.get)
103
+ print("\nThe identity of the test speaker:\n", identified, "with a similarity with test of",
104
+ speakers[identified]*100, "percent match as compared to all.")
105
+ return identified
106
+
107
+
108
+ # Update the function parameter to use the Body module and media_type
109
+ @app.post("/predict/")
110
+ async def resultGenerator(names: List[str] = Form(...), file_uploads: List[UploadFile] = File(...), usersNum: str = Form(...), recordingsNum: str = Form(...)):
111
+ # equal to 2 because names list is of the form [name1, name2,..., test]
112
+ try:
113
+ if (len(names) <= 2):
114
+ return {"error: ", "Incorrect data provided"}
115
+ else:
116
+ result = await predictor(names, file_uploads, usersNum, recordingsNum)
117
+ print('## Test Audio Belonged To: {}'.format(result))
118
+ delFiles() # to delete all files from backend, used in this identification
119
+ return {"result": result}
120
+ except:
121
+ return {"error": "Server not responding"}