from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
from transformers.utils import logging
import gradio as gr

# Define the logger instance for the transformers library
logger = logging.get_logger("transformers")

# Load the model and tokenizer
model_name = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ"
# Alternatives tried: "openai-community/gpt2", "TheBloke/Llama-2-7B-Chat-GGML", "TheBloke/zephyr-7B-beta-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
# model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=False, revision="main")
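# NOTE (environment assumption): loading a GPTQ checkpoint through transformers
# also requires the auto-gptq and optimum packages to be installed, e.g.
#   pip install auto-gptq optimum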

# Generate text using the model and tokenizer
def generate_text(input_text):
    # Move the inputs to the model's device (device_map="auto" may place it on GPU)
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
    output = model.generate(input_ids, max_new_tokens=512, top_k=50, top_p=0.95, temperature=0.7, do_sample=True)
    return tokenizer.decode(output[0], skip_special_tokens=True)
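# Quick smoke test before launching the UI (the [INST] ... [/INST] wrapper is
# the prompt format Mistral-Instruct checkpoints expect):
# print(generate_text("[INST] Write a one-sentence story about a dragon. [/INST]"))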

# Disable the ExLlama backend (if applicable in your configuration).
# This runs before the first generation request, but updating model.config
# after loading may not affect kernels that were already initialized; see the
# load-time alternative sketched below.
config = {"disable_exllama": True}
model.config.update(config)
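# A more reliable way (sketch, untested here) is to disable ExLlama when the
# model is loaded, via a quantization config. The exact flag depends on your
# transformers version: older releases take disable_exllama=True, newer ones
# take use_exllama=False.
#
# from transformers import GPTQConfig
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     device_map="auto",
#     revision="main",
#     quantization_config=GPTQConfig(bits=4, disable_exllama=True),
# )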

# def generate_text(prompt):
#     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
#     summary_ids = model.generate(inputs["input_ids"], max_new_tokens=512, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
#     return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# # For saving the fine-tuned model and tokenizer once training data has been collected:
# # model.save_pretrained("model")
# # tokenizer.save_pretrained("model")

# # Helper functions for the Blocks app below.
# # NOTE: `history` must be defined before these are used, e.g.:
# # history = []

# def show_output_text(message):
#     history.append((message, ""))
#     story = generate_text(message)
#     history[-1] = (message, story)
#     return story

# def clear_textbox():
#     return None, None

# Create the Gradio interface
interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="TeLLMyStory",
    description="Enter your story idea and the model will generate a story based on it.",
)
# Alternative Blocks-based chat UI (not used):
# with gr.Blocks() as demo:
#     gr.Markdown("TeLLMyStory chatbot")
#     with gr.Row():
#         input_text = gr.Textbox(label="Enter your story idea here", placeholder="Once upon a time...")
#         clear_button = gr.Button("Clear", variant="secondary")
#         submit_button = gr.Button("Submit", variant="primary")

#     with gr.Row():
#         gr.Markdown("And see the story take shape here")
#         output_text = gr.Textbox(label="History")

#     submit_button.click(fn=show_output_text, inputs=input_text, outputs=output_text)
#     clear_button.click(fn=clear_textbox, outputs=[input_text, output_text])
# Launch the interface
interface.launch()