from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSeq2SeqLM
from transformers.utils import logging
import gradio as gr
# Define the logger instance for the transformers library
logger = logging.get_logger("transformers")
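# The logger is currently unused; transformers' verbosity could be raised while
# debugging, e.g. logging.set_verbosity_info()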
# Load the model and tokenizer
model_name = "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ" #"openai-community/gpt2" or "TheBloke/Mistral-7B-Instruct-v0.1-GPTQ" or "TheBloke/Llama-2-7B-Chat-GGML" or "TheBloke/zephyr-7B-beta-GPTQ"
tokenizer = AutoTokenizer.from_pretrained(model_name,use_fast=True)
#model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,device_map="auto",trust_remote_code=False,revision="main")
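# Note: loading a GPTQ checkpoint this way assumes the optimum and auto-gptq
# packages are installed; device_map="auto" lets accelerate place the weights on
# the available GPU(s) or fall back to CPU.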
# Generate text from a prompt with the loaded model and tokenizer
def generate_text(input_text):
    # device_map="auto" may place the model on GPU, so move the inputs to the model's device
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
    output = model.generate(input_ids, max_new_tokens=512, top_k=50, top_p=0.95, temperature=0.7, do_sample=True)
    return tokenizer.decode(output[0], skip_special_tokens=True)
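# Quick smoke test before wiring up the UI (hypothetical prompt):
# print(generate_text("Once upon a time"))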
# Example of disabling the ExLlama backend (if applicable in your configuration).
# Note: for a GPTQ checkpoint this flag is only read at load time, so updating
# model.config after from_pretrained() is likely a no-op; see the sketch below.
config = {"disable_exllama": True}
model.config.update(config)
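# A minimal sketch of disabling ExLlama at load time instead, assuming a
# transformers version that provides GPTQConfig (>= 4.33); kept commented out
# here to avoid loading the model twice:
# from transformers import GPTQConfig
# quantization_config = GPTQConfig(bits=4, disable_exllama=True)
# model = AutoModelForCausalLM.from_pretrained(
#     model_name,
#     device_map="auto",
#     revision="main",
#     quantization_config=quantization_config,
# )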
# Alternative beam-search variant of generate_text:
# def generate_text(prompt):
#     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512, padding="max_length")
#     summary_ids = model.generate(inputs["input_ids"], max_new_tokens=512, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
#     return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
# # For saving the model after training on the collected data:
# # model.save_pretrained("model")
# # tokenizer.save_pretrained("model")
# # App helper functions (note: history would need to be defined first, e.g. history = []):
# def show_output_text(message):
#     history.append((message, ""))
#     story = generate_text(message)
#     history[-1] = (message, story)
#     return story
# def clear_textbox():
#     return None, None
# Create the Gradio input interface
interface = gr.Interface(fn=generate_text, inputs="text", outputs="text", title="TeLLMyStory", description="Enter your story idea and the model will generate the story based on it.")
# Alternative Blocks-based UI with history and a clear button:
# with gr.Blocks() as demo:
#     gr.Markdown("TeLLMyStory chatbot")
#     with gr.Row():
#         input_text = gr.Textbox(label="Enter your story idea here", placeholder="Once upon a time...")
#         clear_button = gr.Button("Clear", variant="secondary")
#         submit_button = gr.Button("Submit", variant="primary")
#     with gr.Row():
#         gr.Markdown("And see the story take shape here")
#         output_text = gr.Textbox(label="History")
#     submit_button.click(fn=show_output_text, inputs=input_text, outputs=output_text)
#     clear_button.click(fn=clear_textbox, outputs=[input_text, output_text])
# Launch the interface
interface.launch()
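# When running outside a hosted environment, a temporary public URL can be
# requested instead (assumes network access): interface.launch(share=True)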