import os
import time
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
# Loading the tokenizer and model from Hugging Face's model hub.
tokenizer = AutoTokenizer.from_pretrained("soketlabs/pragna-1b", token=os.environ.get('HF_TOKEN'))
model = AutoModelForCausalLM.from_pretrained(
    "soketlabs/pragna-1b",
    token=os.environ.get('HF_TOKEN'),
    revision='3c5b8b1309f7d89710331ba2f164570608af0de7'
)
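# Attach the instruction-tuned adapter on top of the base pragna-1b weights
# (uses the PEFT adapter integration in transformers).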
model.load_adapter('soketlabs/pragna-1b-it-v0.1', token=os.environ.get('HF_TOKEN'))
# using CUDA for an optimal experience
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
# Defining a custom stopping criteria class for the model's text generation.
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        stop_ids = [2]  # IDs of tokens where the generation should stop.
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:  # Checking if the last generated token is a stop token.
                return True
        return False
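# Note: stop_ids hard-codes token id 2 as the </s>/EOS id, which is assumed to match
# this tokenizer; tokenizer.eos_token_id could be used instead of the magic number.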
# Function to generate model predictions.
def predict(message, history):
    history_transformer_format = history + [[message, ""]]
    stop = StopOnTokens()
    sys_prompt = "You are Pragna, an AI built by Soket AI Labs. You should never lie and always tell facts. Help the user as much as you can, and be open to saying 'I don't know' if you are not sure of the answer."
    eos_token = tokenizer.eos_token
    messages = f'<|system|>\n{sys_prompt}{eos_token}'
    # Formatting the input for the model.
    messages += "</s>".join(["</s>".join(["<|user|>\n" + item[0], "<|assistant|>\n" + item[1]])
                             for item in history_transformer_format])
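    # The resulting prompt looks like:
    #   <|system|>\n{system}</s><|user|>\n{user}</s><|assistant|>\n{reply}</s>...
    # and ends with an empty <|assistant|> turn for the model to complete.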
    print(messages)
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        # max_new_tokens=300,
        # do_sample=True,
        # top_p=0.95,
        # top_k=50,
        # temperature=0.3,
        # repetition_penalty=10.,
        # num_beams=1,
        max_new_tokens=300,
        do_sample=True,
        top_k=5,
        num_beams=1,
        use_cache=False,
        temperature=0.4,
        repetition_penalty=1.2,
        stopping_criteria=StoppingCriteriaList([stop])
    )
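    # generate() runs in a background thread so the streamer can be consumed here
    # token by token and the partial text streamed back to the Gradio UI as it is produced.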
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()  # Starting the generation in a separate thread.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        if '</s>' in partial_message:  # Breaking the loop if the stop token is generated.
            break
        yield partial_message
def slow_echo(message, history):
    for i in range(len(message)):
        time.sleep(0.05)
        yield "You typed: " + message[: i+1]
demo = gr.ChatInterface(
    predict,
    chatbot=gr.Chatbot(height=300),
    textbox=gr.Textbox(placeholder="Try Pragna SFT", container=False, scale=7),
    title="pragna-1b-it",
    description="This is pragna-1b-it",
    theme="soft",
    examples=['Hi!'],
    cache_examples=False,
    retry_btn=None,
    undo_btn="Delete Previous",
    clear_btn="Clear",
).queue()
if __name__ == "__main__":
    demo.launch()