import streamlit as st
from streamlit_chat import message
from streamlit_extras.colored_header import colored_header
from streamlit_extras.add_vertical_space import add_vertical_space
from hugchat import hugchat
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("Celestinian/PromptGPT")
model = AutoModelForCausalLM.from_pretrained("Celestinian/PromptGPT").to(device)
model.eval()  # inference only; the inputs are moved to the same device below
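# Optional (an assumption, not in the original): since Streamlit reruns the
# whole script on every interaction, the weights above are reloaded each time.
# Wrapping the loading in a cached function avoids that, e.g.:
#
#     @st.cache_resource
#     def load_model():
#         tok = AutoTokenizer.from_pretrained("Celestinian/PromptGPT")
#         mdl = AutoModelForCausalLM.from_pretrained("Celestinian/PromptGPT").to(device)
#         return tok, mdl
#
#     tokenizer, model = load_model()
#
# (If you do this, call st.set_page_config before the first load_model() call,
# since set_page_config must precede other Streamlit output.)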

st.set_page_config(page_title="EinfachChat")

# Sidebar contents
with st.sidebar:
    st.title('EinfachChat')
    max_length = st.slider('Max Length', min_value=10, max_value=100, value=30)
    do_sample = st.checkbox('Do Sample', value=True)
    temperature = st.slider('Temperature', min_value=0.1, max_value=1.0, value=0.4)
    no_repeat_ngram_size = st.slider('No Repeat N-Gram Size', min_value=1, max_value=10, value=1)
    top_k = st.slider('Top K', min_value=1, max_value=100, value=50)
    top_p = st.slider('Top P', min_value=0.1, max_value=1.0, value=0.2)

# Rest of your original Streamlit code goes here; a minimal sketch of the chat layout follows.
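# --- Minimal chat-UI sketch (an assumption; the original layout is elided
# above). It shows a colored header, a session-state chat history, and a text
# box bound to `user_input`, which the conditional block at the bottom of the
# file consumes. Replace it with your own layout if you have one. ---
colored_header(label='EinfachChat', description='', color_name='blue-70')
add_vertical_space(1)

# The chat history lives in session state so it survives Streamlit reruns.
if 'past' not in st.session_state:
    st.session_state['past'] = []        # user messages
if 'generated' not in st.session_state:
    st.session_state['generated'] = []   # model responses

user_input = st.text_input('You:', key='input')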

def generate_text(prompt, max_length, do_sample, temperature, no_repeat_ngram_size, top_k, top_p):
    # Prepend a newline and make sure the prompt contains a comma.
    formatted_prompt = "\n" + prompt
    if ',' not in prompt:
        formatted_prompt += ','
    # Tokenize and move the inputs to the same device as the model.
    inputs = tokenizer(formatted_prompt, return_tensors='pt')
    inputs = {key: value.to(device) for key, value in inputs.items()}
    out = model.generate(**inputs, max_length=max_length, do_sample=do_sample,
                         temperature=temperature, no_repeat_ngram_size=no_repeat_ngram_size,
                         top_k=top_k, top_p=top_p,
                         pad_token_id=tokenizer.eos_token_id)  # silence the pad-token warning
    output = tokenizer.decode(out[0], skip_special_tokens=True)
    return output.strip()  # drop the leading newline that was prepended above

# Inside the conditional display part, the original
# `response = generate_response(user_input)` call becomes:
if user_input:
    response = generate_text(user_input, max_length, do_sample, temperature,
                             no_repeat_ngram_size, top_k, top_p)
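    # Store the exchange in the session-state history assumed above.
    st.session_state['past'].append(user_input)
    st.session_state['generated'].append(response)

# Render the conversation with streamlit_chat's `message` widget (a sketch
# matching the session-state layout assumed above; the `key` arguments only
# need to be unique per widget).
for i in range(len(st.session_state['generated'])):
    message(st.session_state['past'][i], is_user=True, key=f'{i}_user')
    message(st.session_state['generated'][i], key=str(i))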