asigalov61's picture
0c2c10c verified
history blame
9.47 kB
import os
import hashlib
import time
import datetime
from pytz import timezone
import copy
from collections import Counter
import random
import statistics
import re
import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
import TMIDIX
# Hugging Face access token taken from the environment; None when HF_TOKEN is unset.
HF_TOKEN = os.environ.get('HF_TOKEN')
def format_table_data(data_string):
    """Convert a pipe-delimited (Markdown-style) table string into rows of cells.

    Each line of ``data_string`` is split on ``|`` and every cell is stripped
    of surrounding whitespace. Cells that are empty or consist solely of ``-``
    characters are dropped; this removes the blank edge cells produced by
    leading/trailing pipes as well as separator lines such as ``|---|---|``.
    Lines with no cells left are skipped, and the remaining rows are
    right-padded with empty strings so that every row has the same length.

    Args:
        data_string: Table text, one row per line, cells separated by ``|``.

    Returns:
        A list of rows, each a list of cell strings of equal length. An empty
        list is returned when no line yields any cell.
    """
    table_rows = []
    for line in data_string.strip().split("\n"):
        stripped_cells = (part.strip() for part in line.split("|"))
        # cell.strip("-") is empty exactly for "" and dash-only cells.
        cells = [cell for cell in stripped_cells if cell.strip("-")]
        if cells:
            table_rows.append(cells)

    # Nothing usable was parsed; max() below would raise on an empty list.
    if not table_rows:
        return []

    # Pad every row to the widest row so the result is rectangular.
    width = max(len(row) for row in table_rows)
    for row in table_rows:
        row.extend([""] * (width - len(row)))

    return table_rows
# Display name -> model identifier. The key matches the value offered by the
# 'Select model' dropdown.
# NOTE(review): the literal was left unclosed in this copy of the file; only
# the closing brace is added here. Further entries may have been lost - confirm.
MODELS = {'Mistral Nemo Instruct 2407': 'mistralai/Mistral-Nemo-Instruct-2407'}
# Identify a MIDI file: hash it, look the hash up in the Monster MIDI titles
# data and the MIDID dataset, then ask a chat model for a summary table.
# Takes the MIDI file path and the selected model name; returns a 5-tuple
# (md5 hash string, record count, source dataset, prompt data text, table rows).
# Reads module-level names that are assigned under the __main__ guard:
# midid_dataset, midid_md5_hashes and monster_midi_titles.
# NOTE(review): this copy of the function has lost its indentation and several
# lines / line tails are missing (each is flagged below). The code itself is
# left untouched.
def ID_MIDI(input_midi, input_model):
print('*' * 70)
# NOTE(review): the argument of .format( and the closing parentheses are
# missing on the next line - presumably a timestamp; TODO restore.
print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(
# Wall-clock start, used for the execution-time report at the end.
start_time = time.time()
print('=' * 70)
print('Loading MIDI...')
fn = os.path.basename(input_midi)
# fn1 (file name up to the first '.') is not referenced again in the visible code.
fn1 = fn.split('.')[0]
# Reads the whole file as bytes; the file object is never explicitly closed.
fdata = open(input_midi, 'rb').read()
# md5 of the raw upload - only printed below, not used for lookups.
input_midi_md5hash = hashlib.md5(fdata).hexdigest()
print('=' * 70)
print('Requested settings:')
print('=' * 70)
print('Input MIDI file name:', fn)
print('Input MIDI md5 hash', input_midi_md5hash)
print('Input model:', input_model)
print('=' * 70)
print('Processing MIDI...Please wait...')
# Round-trip the file through TMIDIX (midi2score -> score2midi) and hash the
# re-encoded bytes; this hash is the key for both lookups below.
new_midi_data = TMIDIX.score2midi(TMIDIX.midi2score(fdata))
new_midi_md5hash = hashlib.md5(new_midi_data).hexdigest()
print('New md5 hash:', new_midi_md5hash)
print('=' * 70)
print('Processing...Please wait...')
# Defaults reported when the hash is not found.
output_str = 'None'
output_midi_records_count = 0
output_midi_src_dataset= 'Unknown'
output_midi_path_str = 'None'
# Enhanced score notes; in the visible code they feed only the text description.
raw_score = TMIDIX.midi2single_track_ms_score(fdata)
escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes, sort_drums_last=True)
# These two assignments overwrite the 'Unknown' / 'None' defaults set above.
output_midi_src_dataset = 'unknown'
output_midi_path_str = 'none'
# Known title: pick one stored 'song --- artist' string at random and split it.
if new_midi_md5hash in monster_midi_titles['md5_hashes_titles_dict']:
title = random.choice(monster_midi_titles['md5_hashes_titles_dict'][new_midi_md5hash]).split(' --- ')
song = title[0]
artist = title[1]
song_description = TMIDIX.escore_notes_to_text_description(escore_notes, song_name=song, artist_name=artist)
# NOTE(review): as written the next line overwrites the titled description;
# presumably it belonged to an else branch that was lost - TODO confirm.
song_description = TMIDIX.escore_notes_to_text_description(escore_notes)
# MIDID lookup: membership test followed by .index() on the hash column.
if new_midi_md5hash in midid_md5_hashes:
midid_entry_idx = midid_md5_hashes.index(new_midi_md5hash)
MIDID_record = midid_dataset[midid_entry_idx]['midid']
output_midi_records_count = len(MIDID_record)
# One record is chosen at random; element 0 is used as the source dataset
# and element 1 as the path string.
output_entry = random.choice(MIDID_record)
output_midi_src_dataset = output_entry[0]
# The regex matches every character except letters, digits, '.', '(', ')',
# space and newline; presumably clean_string removes those matches - confirm.
output_midi_path_str = TMIDIX.clean_string(output_entry[1], regex=r'[^a-zA-Z0-9.() \n]')
client = InferenceClient(api_key=HF_TOKEN)
prompt = "Please create a summary table for a MIDI file based on the following keywords strings, best possible description and best possible summary fields. Please respond with the table only. Do not say anything else. Thank you."
# Payload appended to the prompt: source dataset, path keywords, description.
data = 'Source MIDI dataset: ' + output_midi_src_dataset + '\n\n'
data += 'MIDI keywords strings:' + '\n'
data += output_midi_path_str + '\n\n'
data += 'Music description:' + '\n'
data += song_description
# NOTE(review): the dict braces and closing bracket of the messages list are
# missing below - TODO restore.
messages = [
"role": "user",
"content": prompt + "\n\n" + data
# NOTE(review): the right-hand side of this assignment (the inference call
# that uses client and messages) is missing - TODO restore.
completion =
output_str = completion.choices[0].message['content']
# Turn the model's pipe-delimited table text into rows of cells.
output_table_data = format_table_data(output_str)
print('=' * 70)
print('Original MIDI unique records count', output_midi_records_count)
print('Original MIDI dataset', output_midi_src_dataset)
# The label says 'path string' but the value printed is the full data payload.
print('Original MIDI path string', data)
print('=' * 70)
output_midi_md5 = str(new_midi_md5hash)
# NOTE(review): the argument of .format( and the closing parentheses are
# missing on the next line as well - TODO restore.
print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(
print('-' * 70)
print('Req execution time:', (time.time() - start_time), 'sec')
print('*' * 70)
# The fourth value returned is data (the payload text), not output_midi_path_str.
return output_midi_md5, output_midi_records_count, output_midi_src_dataset, data, output_table_data
if __name__ == "__main__":
PDT = timezone('US/Pacific')
print('=' * 70)
print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(
print('=' * 70)
print('Loading MIDID database...')
midid_dataset = load_dataset("asigalov61/MIDID")['train']
midid_md5_hashes = midid_dataset['midi_hash']
print('=' * 70)
print('Loading Monster MIDI titles database...')
monster_midi_titles = TMIDIX.Tegridy_Any_Pickle_File_Reader('Monster_MIDI_Titles_Database_CC_BY_NC_SA.pickle')
print('=' * 70)
app = gr.Blocks()
with app:
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Identification</h1>")
gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Identify any MIDI in a comprehensive database of 2.32M+ MIDI records</h1>")
gr.Markdown("This is a demo for tegridy-tools, MIDID and Monster MIDI dataset\n\n"
"Please see [tegridy-tools](, [MIDID]( and [Monster MIDI Dataset]( repos for more information\n\n"
gr.Markdown("## Upload your MIDI")
input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"], type="filepath")
input_model = gr.Dropdown(['Mistral Nemo Instruct 2407', 'Mistral Nemo Instruct 2407'],
value='Mistral Nemo Instruct 2407',
label='Select model'
submit = gr.Button("Identify MIDI", variant="primary")
gr.Markdown("## MIDI identification results")
output_midi_md5 = gr.Textbox(label="Monster MIDI dataset md5 hash")
output_midi_records_count = gr.Textbox(label="Original MIDI unique records count")
output_midi_src_dataset = gr.Textbox(label="Original MIDI dataset pretty name")
output_midi_path_str = gr.Textbox(label="Original MIDI raw path string")
output_MIDID_results_table = gr.Dataframe(label="MIDID database results table", wrap=True, col_count=(3, 'dynamic'))
run_event =, [input_midi,