import os

# Hide CUDA devices before torch/unsloth are imported, so CUDA is never
# initialized and everything runs on the CPU
os.environ["CUDA_VISIBLE_DEVICES"] = ""

from threading import Thread

import gradio as gr
import torch
from transformers import TextIteratorStreamer
from unsloth import FastLanguageModel

# Run on CPU
device = torch.device("cpu")
model_name_or_path = "michailroussos/model_llama_8d"
max_seq_length = 2048
dtype = None  # None lets Unsloth auto-detect a suitable dtype
# Load the model and tokenizer. from_pretrained returns a (model, tokenizer)
# tuple, so the tuple itself must not be moved with .to()
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name_or_path,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=False,  # bitsandbytes 4-bit quantization requires CUDA
)
model = model.to(device)  # ensure the model is on the CPU
# Enable native faster inference if possible
FastLanguageModel.for_inference(model)
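# Note (assumption, not verified against every release): Unsloth is primarily
# optimized for CUDA GPUs, so CPU execution may be slow or fall back to
# standard Transformers code paths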
# Inference function: rebuilds the conversation, then streams tokens back
def respond(message, history, system_message, max_tokens, temperature, top_p):
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # apply_chat_template with return_tensors="pt" returns a tensor of
    # input IDs directly, not a dict with an 'input_ids' key
    input_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    # model.generate returns token IDs, not text chunks, so stream the decoded
    # text with TextIteratorStreamer while generation runs in a background thread
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(target=model.generate, kwargs=dict(
        input_ids=input_ids, streamer=streamer, max_new_tokens=max_tokens,
        temperature=temperature, top_p=top_p, do_sample=True,
    )).start()

    response = ""
    for token in streamer:
        response += token
        yield response
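# Because respond is a generator, gr.ChatInterface streams each yielded
# partial response to the UI as it arrives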
# Create the Gradio chat interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()
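# To run locally (assuming the dependencies are installed, e.g.
# `pip install unsloth gradio`, and this file is saved as app.py,
# the usual Spaces entry point):
#     python app.py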