File size: 6,593 Bytes
c10c249
cc4a711
c10c249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c5824b7
 
c10c249
 
 
 
2dbbde3
 
 
 
c5824b7
6387673
c5824b7
 
 
32af644
 
 
 
 
 
 
6387673
32af644
 
6387673
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32af644
 
c5824b7
6387673
c5824b7
 
 
 
c10c249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7724541
c10c249
7724541
c10c249
7724541
c10c249
 
 
 
 
7724541
c10c249
c5824b7
 
 
5e34946
7724541
c5824b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c10c249
c5824b7
c10c249
41705b0
c10c249
 
 
4d700d3
 
 
c10c249
 
 
c5824b7
 
c10c249
 
 
 
 
 
 
 
 
c5824b7
c10c249
 
 
 
 
 
 
 
 
 
 
c5824b7
 
 
 
 
b49f486
c10c249
 
 
cc4a711
41705b0
c10c249
cc4a711
c10c249
 
 
 
 
 
 
5015e4e
c10c249
cc4a711
c10c249
c5824b7
27814ce
c10c249
cc4a711
c10c249
 
c5824b7
c10c249
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
#================================================================
# https://huggingface.co/spaces/asigalov61/MIDI-Identification
#================================================================

import os
import hashlib

import time
import datetime
from pytz import timezone

import copy
from collections import Counter
import random
import statistics

import gradio as gr

from huggingface_hub import InferenceClient

import TMIDIX

#==========================================================================================================

# Hugging Face API token taken from the process environment (None when the
# variable is unset); ID_MIDI hands it to InferenceClient.
HF_TOKEN = os.environ.get('HF_TOKEN')

#==========================================================================================================

def format_table_data(data_string):
    """Parse pipe-delimited (markdown-style) table text into a list of rows.

    Each non-blank line of ``data_string`` is split on ``|`` and every cell is
    stripped of surrounding whitespace. Ragged rows are padded with ``""`` to
    the width of the widest row. Columns that have content in fewer than half
    of the rows are then dropped. This removes, for example, the empty cells
    produced by leading/trailing pipes and a stray column created by a preamble
    sentence. Rows left with no content at all are dropped as well.

    Args:
        data_string: Raw text, typically a chat-model reply containing a table.

    Returns:
        A rectangular list of lists of strings. The result is never empty:
        ``[[""]]`` is returned when nothing survives the filtering.
    """
    # Blank lines carry no cells. Kept, they would become one-cell rows that
    # force the rest of the table into a single column.
    lines = [line for line in data_string.strip().split("\n") if line.strip()]
    if not lines:
        return [[""]]

    rows = [[cell.strip() for cell in line.split("|")] for line in lines]

    # Pad ragged rows rather than truncating them to the shortest row, so that
    # a short preamble line cannot wipe out the real table columns.
    width = max(len(row) for row in rows)
    rows = [row + [""] * (width - len(row)) for row in rows]

    # Keep only columns that are populated in at least half of the rows.
    min_filled = len(rows) * 0.5
    kept_columns = [
        column for column in zip(*rows)
        if sum(1 for cell in column if cell) >= min_filled
    ]

    # Transpose back and drop rows that ended up entirely empty.
    table = [list(row) for row in zip(*kept_columns) if any(row)]

    return table or [[""]]

#==========================================================================================================

def ID_MIDI(input_midi):
    """Identify an uploaded MIDI file against the MIDID database.

    The file bytes are round-tripped through TMIDIX (``midi2score`` followed by
    ``score2midi``). The md5 of the re-encoded bytes is looked up in the
    module-level ``MIDID_database``. On a hit, the ``'midi_path'`` string of the
    first matching record is sent to a hosted chat model, and the model's reply
    is parsed into table rows with ``format_table_data``.

    Relies on the module-level globals ``PDT`` and ``MIDID_database``. This
    script creates both only in its ``__main__`` section.

    Args:
        input_midi: Filesystem path of the uploaded MIDI file (the Gradio
            ``gr.File(type="filepath")`` value).

    Returns:
        A ``(md5_hex_of_uploaded_bytes, table_rows)`` tuple, where
        ``table_rows`` is a list of lists of strings for ``gr.Dataframe``.
    """
    print('*' * 70)
    print('Req start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    start_time = time.time()

    print('=' * 70)
    print('Loading MIDI...')

    fn = os.path.basename(input_midi)

    # Context manager so the file handle is closed as soon as it is read.
    with open(input_midi, 'rb') as midi_file:
        fdata = midi_file.read()

    input_midi_md5hash = hashlib.md5(fdata).hexdigest()

    print('=' * 70)
    print('Requested settings:')
    print('=' * 70)
    print('Input MIDI file name:', fn)
    print('Input MIDI md5 hash', input_midi_md5hash)
    print('=' * 70)
    print('Processing MIDI...Please wait...')

    #=======================================================
    # START PROCESSING

    # The database is keyed by the md5 of the TMIDIX re-encoded bytes, not of
    # the raw upload, so re-encode before hashing.
    new_midi_data = TMIDIX.score2midi(TMIDIX.midi2score(fdata))

    new_midi_md5hash = hashlib.md5(new_midi_data).hexdigest()

    print('New md5 hash:', new_midi_md5hash)

    print('Done!')
    print('=' * 70)
    print('Processing...Please wait...')

    if new_midi_md5hash in MIDID_database:

        client = InferenceClient(api_key=HF_TOKEN)

        prompt = "Please create a summary table for a MIDI file based on the following keywords strings. Please add best possible description and best possible summary fields. Please respond with the table only. Do not say anything else. Thank you."

        # Only the first matching record is summarised.
        data = MIDID_database[new_midi_md5hash][0]['midi_path']

        messages = [
            {
                "role": "user",
                "content": prompt + "\n\n" + data
            }
        ]

        completion = client.chat.completions.create(
            # Alternative model: "Qwen/Qwen2.5-72B-Instruct"
            model="mistralai/Mistral-Nemo-Instruct-2407",
            messages=messages,
            max_tokens=500
        )

        output_str = completion.choices[0].message['content']

        output_table_data = format_table_data(output_str)

    else:

        output_table_data = [['No matching MIDI ID records found', 'Unknown MIDI', 'Sorry :(']]

        # Assign output_str on this path too; the log print below reads it.
        output_str = 'No matching MIDI ID records found'

    print('Done!')
    print('=' * 70)

    print(output_str)
    print('=')

    #========================================================

    output_midi_md5 = str(input_midi_md5hash)

    #========================================================

    print('Req end time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print('-' * 70)
    print('Req execution time:', (time.time() - start_time), 'sec')
    print('*' * 70)

    #========================================================

    return output_midi_md5, output_table_data
    
#==========================================================================================================

if __name__ == "__main__":

    # US/Pacific clock for every app/request timestamp; ID_MIDI reads this
    # module-level name.
    PDT = timezone('US/Pacific')

    separator = '=' * 70

    print(separator)
    print('App start time: {:%Y-%m-%d %H:%M:%S}'.format(datetime.datetime.now(PDT)))
    print(separator)

    # Hash -> records lookup table that ID_MIDI consults by name.
    # NOTE(review): presumably this unpickles a file shipped with the app; it
    # should never be pointed at user-supplied data.
    print('Loading MIDID database...')
    MIDID_database = TMIDIX.Tegridy_Any_Pickle_File_Reader('MIDID_Basic_Database_CC_BY_NC_SA.pickle')

    print(separator)

    title_html = "<h1 style='text-align: center; margin-bottom: 1rem'>MIDI Identification</h1>"
    subtitle_html = "<h1 style='text-align: center; margin-bottom: 1rem'>Identify any MIDI in a comprehensive database of 1.42M+ MIDI records</h1>"
    intro_md = (
        "This is a demo for tegridy-tools\n\n"
        "Please see [tegridy-tools](https://github.com/asigalov61/tegridy-tools) GitHub repo for more information\n\n"
    )

    # Components render in creation order, so this sequence is the page layout.
    with gr.Blocks() as app:

        gr.Markdown(title_html)
        gr.Markdown(subtitle_html)
        gr.Markdown(intro_md)

        gr.Markdown("## Upload your MIDI")
        input_midi = gr.File(
            label="Input MIDI",
            file_types=[".midi", ".mid", ".kar"],
            type="filepath",
        )
        submit = gr.Button("Identify MIDI", variant="primary")

        gr.Markdown("## MIDI identification results")
        output_midi_md5 = gr.Textbox(label="Input MIDI md5 hash")
        output_MIDID_results_table = gr.Dataframe(
            label="MIDID results table",
            wrap=True,
            col_count=(3, 'dynamic'),
        )

        # Button press: uploaded file path in -> (md5 text, results table) out.
        run_event = submit.click(
            fn=ID_MIDI,
            inputs=[input_midi],
            outputs=[output_midi_md5, output_MIDID_results_table],
        )

    app.queue().launch()