miracFence committed
Commit d7db62d · verified · 1 Parent(s): 3df4d21

Update app.py

Files changed (1)
  1. app.py +36 -6
app.py CHANGED
@@ -1,8 +1,12 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 import torch
 import spaces
 
+import os
+from threading import Thread
+from typing import Iterator
+
 # Define quantization configuration
 quantization_config = BitsAndBytesConfig(
     load_in_4bit=True, # Specify 4-bit quantization
@@ -14,6 +18,7 @@ quantization_config = BitsAndBytesConfig(
 # Load the tokenizer and quantized model from Hugging Face
 model_name = "llSourcell/medllama2_7b"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 # Load model with quantization
 model = AutoModelForCausalLM.from_pretrained(model_name,
@@ -29,24 +34,49 @@ def format_history(msg: str, history: list[list[str, str]], system_prompt: str):
     return chat_history
 
 @spaces.GPU(duration=90)
-def generate_response(msg: str, history: list[list[str, str]], system_prompt: str):
+def generate(msg: str,
+             history: list[list[str, str]],
+             system_prompt: str,
+             max_new_tokens: int = 1024,
+             temperature: float = 0.6,
+             top_p: float = 0.9,
+             top_k: int = 50,
+             repetition_penalty: float = 1.2) -> Iterator[str]:
     chat_history = format_history(msg, history, system_prompt)
 
     # Tokenize the input prompt
     inputs = tokenizer(chat_history, return_tensors="pt").to("cuda")
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
 
     # Generate a response using the model
-    outputs = model.generate(inputs["input_ids"], max_length=500, pad_token_id=tokenizer.eos_token_id)
+    # outputs = model.generate(inputs["input_ids"], max_length=500, pad_token_id=tokenizer.eos_token_id)
 
     # Decode the response back to a string
-    response = tokenizer.decode(outputs[:, inputs["input_ids"].shape[-1]:][0], skip_special_tokens=True)
+    # response = tokenizer.decode(outputs[:, inputs["input_ids"].shape[-1]:][0], skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": inputs["input_ids"]},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
 
     # Yield the generated response
-    yield response
+    # yield response
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
 
 # Define the Gradio ChatInterface
 chatbot = gr.ChatInterface(
-    generate_response,
+    fn=generate,
    chatbot=gr.Chatbot(
        height="64vh"
    ),
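The substance of this commit is the switch from a single blocking model.generate call to token-by-token streaming: generation runs on a worker thread, a TextIteratorStreamer receives decoded text as it is produced, and the function yields the accumulated reply so Gradio can re-render it incrementally. A minimal, self-contained sketch of the same pattern follows; the tiny public model sshleifer/tiny-gpt2 stands in for the quantized medllama2_7b so the sketch runs on CPU, and the sampling values are illustrative, not taken from the commit.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Stand-in model so the sketch runs without a GPU; the Space itself loads
# llSourcell/medllama2_7b with 4-bit quantization on CUDA.
name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

def stream_reply(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt")
    # skip_prompt=True keeps the echoed prompt out of the streamed text
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate blocks until it finishes, so it runs on a worker thread
    # while this function drains the streamer's queue
    thread = Thread(target=model.generate, kwargs=dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        max_new_tokens=40,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    ))
    thread.start()
    pieces = []
    for piece in streamer:        # each iteration delivers newly decoded text
        pieces.append(piece)
        yield "".join(pieces)     # yield the running total, as app.py does
    thread.join()

for partial in stream_reply("The patient reports"):
    print(partial)

The timeout=20.0 that app.py passes to TextIteratorStreamer makes the consuming loop raise queue.Empty if no new token arrives within 20 seconds, so a stalled generation cannot hang the Gradio worker indefinitely.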
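The diff is truncated inside the gr.ChatInterface(...) call, so it does not show how the new sampling parameters reach the UI. If the Space exposes them as controls, Gradio's additional_inputs hook is the usual route; the sketch below is a hedged guess at that wiring (the Textbox and Slider ranges are illustrative, not part of the commit), assuming the generate function defined above is in scope.

import gradio as gr

# Hypothetical wiring of generate's extra parameters; the actual commit
# may simply rely on the defaults declared in the function signature.
chatbot = gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(height="64vh"),
    additional_inputs=[
        gr.Textbox(label="System prompt", value=""),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
)

ChatInterface calls fn as fn(message, history, *additional_inputs), which lines up with generate's (msg, history, system_prompt, max_new_tokens, temperature, top_p, top_k, repetition_penalty) signature, and because generate is a generator, the interface streams each yielded string straight into the chat window.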