# Hugging Face Spaces status: Running
#================================================================ | |
# https://huggingface.co/spaces/asigalov61/MIDI-Identification | |
#================================================================ | |
import os | |
import hashlib | |
import time | |
import datetime | |
from pytz import timezone | |
import copy | |
from collections import Counter | |
import random | |
import statistics | |
import gradio as gr | |
from huggingface_hub import InferenceClient | |
import TMIDIX | |
#========================================================================================================== | |
# Hugging Face API token taken from the environment; None when the variable is unset.
HF_TOKEN = os.environ.get('HF_TOKEN')
#========================================================================================================== | |
def format_table_data(data_string):
    """Parse a pipe-delimited (Markdown-style) table string into rows of cells.

    Each non-blank line is split on ``|`` and every cell is stripped.
    Markdown delimiter rows (e.g. ``|---|:---:|``) are skipped. Short rows
    are padded with empty strings to the width of the widest row rather
    than truncating the longer ones. Columns that are empty in every row,
    such as those produced by the leading/trailing pipes of ``| a | b |``,
    are removed.

    Args:
        data_string: Table text, typically an LLM reply. ``None`` or an
            empty string is treated as an empty table.

    Returns:
        A list of equal-length lists of strings, suitable for a
        ``gr.Dataframe``. An empty table yields ``[[""]]``.
    """
    rows = []
    for line in (data_string or "").strip().splitlines():
        if not line.strip():
            continue  # ignore blank lines inside or around the table
        cells = [cell.strip() for cell in line.split("|")]
        non_empty = [cell for cell in cells if cell]
        # Skip Markdown header/body delimiter rows such as |---|:--:|.
        if non_empty and all("-" in cell and set(cell) <= set("-:") for cell in non_empty):
            continue
        rows.append(cells)

    if not rows:
        return [[""]]

    # Pad short rows instead of transposing with zip(*rows) directly, which
    # would silently drop every cell beyond the shortest row's width.
    width = max(len(row) for row in rows)
    padded = [row + [""] * (width - len(row)) for row in rows]

    # Drop columns that carry no data in any row (edge-pipe artifacts).
    columns = [column for column in zip(*padded) if any(column)]
    if not columns:
        return [[""]]
    return [list(row) for row in zip(*columns)]
#========================================================================================================== | |
def _summarize_midi_keywords(keywords):
    """Ask the hosted LLM to turn a MIDID keywords string into a summary table.

    Args:
        keywords: The ``'midi_path'`` string stored for the matching
            database record.

    Returns:
        The model's reply text (requested, but not guaranteed, to be a
        table); an empty string if the reply has no content.
    """
    client = InferenceClient(api_key=HF_TOKEN)
    prompt = "Please create a summary table for a MIDI file based on the following keywords strings. Please add best possible description and best possible summary fields. Please respond with the table only. Do not say anything else. Thank you."
    messages = [
        {
            "role": "user",
            "content": prompt + "\n\n" + keywords,
        }
    ]
    completion = client.chat.completions.create(
        # Alternative model: "Qwen/Qwen2.5-72B-Instruct"
        model="mistralai/Mistral-Nemo-Instruct-2407",
        messages=messages,
        max_tokens=500,
    )
    return completion.choices[0].message['content'] or ""


def ID_MIDI(input_midi):
    """Identify an uploaded MIDI file against the MIDID database.

    The file is round-tripped through TMIDIX (``midi2score`` then
    ``score2midi``), and the MD5 of the re-encoded bytes is looked up in
    ``MIDID_database``. On a hit, the record's ``'midi_path'`` keywords are
    summarized into a table by a hosted LLM.

    Relies on the module globals ``PDT`` and ``MIDID_database``, which are
    created in this script's ``__main__`` section.

    Args:
        input_midi: Filesystem path of the uploaded file (Gradio ``File``
            with ``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        A ``(md5_hex, table_rows)`` tuple: the MD5 of the original upload
        bytes and a list of row lists for the results ``Dataframe``.
    """
    if input_midi is None:
        # Clicking the button without an upload passes None; report it
        # instead of crashing in os.path.basename.
        return '', [['No MIDI file was uploaded', 'Unknown MIDI', 'Please upload a MIDI file']]

    print('*' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = time.time()
    print('=' * 70)
    print('Loading MIDI...')
    fn = os.path.basename(input_midi)
    with open(input_midi, 'rb') as midi_file:
        fdata = midi_file.read()
    input_midi_md5hash = hashlib.md5(fdata).hexdigest()
    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    print('Input MIDI file name:', fn)
    print('Input MIDI md5 hash', input_midi_md5hash)
    print('=' * 70)
    print('Processing MIDI...Please wait...')
    #=======================================================
    # START PROCESSING
    # The lookup key is the MD5 of the TMIDIX round-tripped bytes, not of the
    # raw upload (presumably the database was keyed the same way -- confirm).
    new_midi_data = TMIDIX.score2midi(TMIDIX.midi2score(fdata))
    new_midi_md5hash = hashlib.md5(new_midi_data).hexdigest()
    print('New md5 hash:', new_midi_md5hash)
    print('Done!')
    print('=' * 70)
    print('Processing...Please wait...')
    if new_midi_md5hash in MIDID_database:
        keywords = MIDID_database[new_midi_md5hash][0]['midi_path']
        output_str = _summarize_midi_keywords(keywords)
        output_table_data = format_table_data(output_str)
    else:
        output_table_data = [['No matching MIDI ID records found', 'Unknown MIDI', 'Sorry :(']]
        # output_str must be bound on this path too: it is logged below.
        output_str = 'No matching MIDI ID records found'
    print('Done!')
    print('=' * 70)
    print(output_str)
    print('=' * 70)
    #========================================================
    output_midi_md5 = input_midi_md5hash
    #========================================================
    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (time.time() - start_time), 'sec')
    print('*' * 70)
    #========================================================
    return output_midi_md5, output_table_data
#========================================================================================================== | |
if __name__ == "__main__":
    # Start-up state read by ID_MIDI as module globals: the log timezone and
    # the MIDID lookup table.
    PDT = timezone('US/Pacific')
    divider = '=' * 70

    print(divider)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print(divider)
    print('Loading MIDID database...')
    MIDID_database = TMIDIX.Tegridy_Any_Pickle_File_Reader('MIDID_Basic_Database_CC_BY_NC_SA.pickle')
    print(divider)

    # Build the UI; components appear on the page in creation order.
    with gr.Blocks() as app:
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Identification</h1>")
        gr.Markdown("<h1 style='text-align: center; margin-bottom: 1rem'>Identify any MIDI in a comprehensive database of 1.42M+ MIDI records</h1>")
        gr.Markdown(
            "This is a demo for tegridy-tools\n\n"
            "Please see [tegridy-tools](https://github.com/asigalov61/tegridy-tools) GitHub repo for more information\n\n"
        )

        gr.Markdown("## Upload your MIDI")
        input_midi = gr.File(label="Input MIDI", file_types=[".midi", ".mid", ".kar"], type="filepath")
        submit = gr.Button("Identify MIDI", variant="primary")

        gr.Markdown("## MIDI identification results")
        output_midi_md5 = gr.Textbox(label="Input MIDI md5 hash")
        output_MIDID_results_table = gr.Dataframe(
            label="MIDID results table",
            wrap=True,
            col_count=(3, 'dynamic'),
        )

        # Clicking the button runs ID_MIDI on the uploaded file path and
        # fills the hash textbox and the results table.
        run_event = submit.click(
            fn=ID_MIDI,
            inputs=[input_midi],
            outputs=[output_midi_md5, output_MIDID_results_table],
        )

    app.queue().launch()