File size: 1,653 Bytes
9d46201
 
 
 
4ab88cc
0d61d97
4ab88cc
 
 
 
bfd1de9
4ab88cc
4a01ccd
 
 
 
a1b1972
4a01ccd
 
 
4ab88cc
292ac90
bfd1de9
9d46201
292ac90
4a01ccd
 
963e69f
4a01ccd
4cd22cd
 
9d46201
4a01ccd
9d46201
 
183cde3
0502abd
9d46201
 
 
 
 
 
f43cb7c
 
 
 
 
9d46201
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import streamlit as st
from streamlit.elements.altair import generate_chart
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline

@st.cache(allow_output_mutation=True)
def load_model():
    """Load the tokenizer and causal-LM for the rap-lyrics GPT-2 checkpoint.

    Both objects are built from the "flax-community/gpt2-rap-lyric-generator"
    checkpoint with ``from_flax=True`` and returned as ``(tokenizer, model)``.
    The ``st.cache`` decorator memoizes the result across Streamlit reruns;
    ``allow_output_mutation=True`` tells Streamlit not to hash-check the
    returned objects for mutation.
    """
    checkpoint = "flax-community/gpt2-rap-lyric-generator"
    tok = AutoTokenizer.from_pretrained(checkpoint, from_flax=True)
    lm = AutoModelForCausalLM.from_pretrained(checkpoint, from_flax=True)
    return tok, lm

@st.cache()
def load_rappers(path="rappers.txt"):
    """Return the sorted list of artist names stored one per line in *path*.

    Line endings are removed and blank lines are skipped. The result is
    memoized by Streamlit's ``st.cache`` so the file is not re-read on every
    rerun.

    Args:
        path: Text file to read; defaults to "rappers.txt" as before.

    Returns:
        A new list of names in ascending sort order.
    """
    # `with` guarantees the handle is closed (the old code leaked it), and an
    # explicit encoding keeps non-ASCII artist names independent of locale.
    with open(path, encoding="utf-8") as text_file:
        # rstrip("\n") only removes a real line ending; the previous
        # name[:-1] also chopped the last letter of a final line that had
        # no trailing newline.
        names = [line.rstrip("\n") for line in text_file]
    # Blank lines would otherwise show up as empty options in the selectbox.
    return sorted(name for name in names if name.strip())


# Show a placeholder heading while the cached model loads, then relabel it
# in place once the pipeline is ready.
heading = st.title("Loading model")
tokenizer, model = load_model()
text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
heading.title("Rap lyrics generator")

# Prompt inputs: the artist is picked from the bundled, sorted rappers list.
rapper_choices = tuple(load_rappers())
artist = st.selectbox("Choose your rapper", rapper_choices)
song_name = st.text_input("Enter the desired song name", "Shaolin")



def _render_lyrics(generated_text):
    """Write generated lyrics to the page, one Streamlit element per line.

    On the first line everything before the first "[" is dropped (with the
    prompt built below, that is the "<BOS>" marker and the song name).
    A "<BOS>" marker on a later line is removed, lines starting with "["
    are rendered in bold, and rendering stops at the first "<EOS>" marker;
    any text before "<EOS>" on that line is still shown.
    """
    for count, raw_line in enumerate(generated_text.split("\n")):
        # Look for the end marker before anything else: a header or "<BOS>"
        # line that also contains "<EOS>" must still end the song instead of
        # being printed verbatim while text generated after it keeps flowing.
        line, eos, _ = raw_line.partition("<EOS>")
        if eos and not line.strip():
            break
        if count == 0:
            # find() returns -1 when "[" is missing; slicing from -1 would keep
            # only the last character, so fall back to the marker-free line.
            start = line.find("[")
            st.write(line[start:] if start != -1 else line.replace("<BOS>", "", 1))
        elif "<BOS>" in line:
            # Drop the marker wherever it appears; slicing off 5 characters
            # was only correct when "<BOS>" sat at the very start of the line.
            st.write(line.replace("<BOS>", "", 1))
        elif line.startswith("["):
            st.markdown(f"**{line}**")
        else:
            st.write(line)
        if eos:
            break


if st.button("Generate lyrics", help="Press me!"):
    st.title(f"{artist}: {song_name}")
    # Seed generation with the song name and a first-verse header for the
    # chosen artist; _render_lyrics relies on this "<BOS>... [" framing.
    prefix_text = f"<BOS>{song_name} [Verse 1:{artist}]"
    # do_sample=True samples instead of greedy decoding, so each press can
    # produce a different song.
    generated_song = text_generation(prefix_text, max_length=750, do_sample=True)[0]
    _render_lyrics(generated_song["generated_text"])