from fastapi import FastAPI, UploadFile, File, Body, Form from pathlib import Path from fastapi.middleware.cors import CORSMiddleware from typing import List import numpy as np from resemblyzer import preprocess_wav, VoiceEncoder from itertools import groupby from pathlib import Path from tqdm import tqdm import os from sklearn.metrics.pairwise import cosine_similarity, cosine_distances import glob UPLOAD_DIR = Path() / "uploads" UPLOAD_DIR.mkdir(parents=True, exist_ok=True) os.environ[ 'NUMBA_CACHE_DIR' ] = '/tmp/' app = FastAPI() # Add a CORS middleware to allow cross-origin requests from the frontend app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"], ) # del all files in uploads folder def delFiles(): files = glob.glob('uploads/*') for f in files: os.remove(f) # main function which returns the name of person which has highest similarity index with test audio async def predictor(names, file_uploads, usersNum, recordingsNum): speaker_embed_list = [] encoder = VoiceEncoder() # Iterating over list of files corresponding to each user speaker_wavs_list = [] fileInd = 0 names.pop() # to remove key named "test" # print("file_uploads ", file_uploads, "recordingNums ", recordingsNum) for name in names: wav_fpaths = [] for ind in range(int(recordingsNum)): print("inside yo") file_upload = file_uploads[fileInd] data = await file_upload.read() # appending person's name to the his/her recordings filename = name+"¬"+file_upload.filename file_path = UPLOAD_DIR / filename with open(file_path, "wb") as file_object: file_object.write(data) wav_fpaths.append(Path(file_path)) fileInd += 1 # print("wav_fpaths len", len(wav_fpaths), "name", name) try: speaker_wavs = {speaker: list(map(preprocess_wav, wav_fpaths)) for speaker, wav_fpaths in groupby(tqdm(wav_fpaths, "Preprocessing wavs", len(wav_fpaths), unit="wavs"), lambda wav_fpath: os.path.basename(wav_fpath).split("¬")[0])} # extracting person's name from file name speaker_wavs_list.append(speaker_wavs) except Exception as e: print("An exception occurred:", type(error).__name__) print("Exception details:", error) # make a list of the pre-processed audios ki arrays for sp_wvs in speaker_wavs_list: speaker_embed_list.append( np.array([encoder.embed_speaker(wavs) for wavs in sp_wvs.values()])) # print("preprocessed audio ki array ", speaker_embed_list) # making preprocessed test audio wav_fpaths = [] file_upload = file_uploads[-1] data = await file_upload.read() # print("data", data) filename = "test¬"+file_upload.filename file_path = UPLOAD_DIR / filename # print("filepath", file_path) with open(file_path, "wb") as file_object: file_object.write(data) wav_fpaths.append(Path(file_path)) # print("wav_fpath", wav_fpaths) print("about to test") try: test_pos_wavs = {speaker: list(map(preprocess_wav, wav_fpaths)) for speaker, wav_fpaths in groupby(tqdm(wav_fpaths, "Preprocessing wavs", len(wav_fpaths), unit="wavs"), lambda wav_fpath: "test")} # print("test_pos_wavs", test_pos_wavs) except Exception as error: print("An exception occurred:", type(error).__name__) print("Exception details:", error) test_pos_emb = np.array([encoder.embed_speaker(wavs) for wavs in test_pos_wavs.values()]) # calculates cosine similarity between the ground truth (test file) and registered audios speakers = {} val = 0 for spkr_embd in speaker_embed_list: key_val = names[val] spkr_sim = cosine_similarity(spkr_embd, test_pos_emb)[0][0] speakers[key_val] = spkr_sim val += 1 norm = [float(i)/sum(speakers.values()) for i in speakers.values()] for i in range(len(norm)): key_val = names[i] speakers[key_val] = norm[i] identified = max(speakers, key=speakers.get) print("\nThe identity of the test speaker:\n", identified, "with a similarity with test of", speakers[identified]*100, "percent match as compared to all.") return identified # Update the function parameter to use the Body module and media_type @app.post("/predict/") async def resultGenerator(names: List[str] = Form(...), file_uploads: List[UploadFile] = File(...), usersNum: str = Form(...), recordingsNum: str = Form(...)): # equal to 2 because names list is of the form [name1, name2,..., test] try: if (len(names) <= 2): return {"error: ", "Incorrect data provided"} else: result = await predictor(names, file_uploads, usersNum, recordingsNum) print('## Test Audio Belonged To: {}'.format(result)) delFiles() # to delete all files from backend, used in this identification return {"result": result} except: return {"error": "Server not responding"}