prithivMLmods committed on
Commit 34ddc1f · verified · 1 Parent(s): e5a04df

Update app.py

Files changed (1)
  1. app.py +68 -61
app.py CHANGED
@@ -1,58 +1,25 @@
-from huggingface_hub import InferenceClient
 import gradio as gr
 from fpdf import FPDF
 import docx

 css = '''
-.gradio-container{max-width: 1000px !important}
 h1{text-align:center}
 footer {
     visibility: hidden
 }
 '''

-client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")

-def format_prompt(message, history, system_prompt=None):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    if system_prompt:
-        prompt += f"[SYS] {system_prompt} [/SYS]"
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
-
-def generate(
-    prompt, history, system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0,
-):
-    temperature = float(temperature)
-    if temperature < 1e-2:
-        temperature = 1e-2
-    top_p = float(top_p)
-
-    generate_kwargs = dict(
-        temperature=temperature,
-        max_new_tokens=max_new_tokens,
-        top_p=top_p,
-        repetition_penalty=repetition_penalty,
-        do_sample=True,
-        seed=42,
-    )
-
-    formatted_prompt = format_prompt(prompt, history, system_prompt)
-
-    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-    output = ""
-
-    for response in stream:
-        output += response.token.text
-        # Clean up </s> tags from the generated output
-        output = output.replace("</s>", "")
-        yield output
-    return output

-# Save the generated content to a file
 def save_file(content, filename, file_format):
     if file_format == "pdf":
         pdf = FPDF()
@@ -75,30 +42,70 @@ def save_file(content, filename, file_format):
     else:
         raise ValueError("Unsupported file format")

-def generate_and_save(prompt, history, filename="output", file_format="pdf", system_prompt=None, temperature=0.2, max_new_tokens=1024, top_p=0.95, repetition_penalty=1.0):
-    generated_text = ""
-    for output in generate(prompt, history, system_prompt, temperature, max_new_tokens, top_p, repetition_penalty):
-        generated_text = output
-    generated_text = generated_text.replace("</s>", "")
-    saved_file = save_file(generated_text, filename, file_format)
-    return generated_text, history + [(prompt, generated_text)], saved_file

-demo = gr.Interface(
-    fn=generate_and_save,
-    inputs=[
-        gr.Textbox(placeholder="Type your message here...", label="Chatbot", lines=4),
-        gr.State(value=[]),
-        gr.Textbox(placeholder="Filename (default: output)", label="Filename", value="output"),
         gr.Radio(["pdf", "docx", "txt"], label="File Format", value="pdf"),
     ],
     outputs=[
-        gr.Textbox(label="Generated Text", lines=4),
-        gr.State(value=[]),
-        gr.File(label="Download File")
     ],
     css=css,
-    title="GRAB DOC",
-    theme="bethecloud/storj_theme"
 )

-demo.queue().launch(show_api=False)
 
 
 
 import gradio as gr
+from openai import OpenAI
+import os
 from fpdf import FPDF
 import docx

 css = '''
+.gradio-container{max-width: 890px !important}
 h1{text-align:center}
 footer {
     visibility: hidden
 }
 '''

+ACCESS_TOKEN = os.getenv("HF_TOKEN")

+client = OpenAI(
+    base_url="https://api-inference.huggingface.co/v1/",
+    api_key=ACCESS_TOKEN,
+)

+# Function to save generated text to a file
 def save_file(content, filename, file_format):
     if file_format == "pdf":
         pdf = FPDF()

     else:
         raise ValueError("Unsupported file format")

+# Respond function with file saving
+def respond(
+    message,
+    history: list[tuple[str, str]],
+    system_message,
+    max_tokens,
+    temperature,
+    top_p,
+    filename,
+    file_format
+):
+    messages = [{"role": "system", "content": system_message}]
+
+    for val in history:
+        if val[0]:
+            messages.append({"role": "user", "content": val[0]})
+        if val[1]:
+            messages.append({"role": "assistant", "content": val[1]})
+
+    messages.append({"role": "user", "content": message})
+
+    response = ""
+
+    for message in client.chat.completions.create(
+        model="meta-llama/Meta-Llama-3.1-70B-Instruct",
+        max_tokens=max_tokens,
+        stream=True,
+        temperature=temperature,
+        top_p=top_p,
+        messages=messages,
+    ):
+        token = message.choices[0].delta.content
+        response += token
+        yield response

+    # Save the final response to the specified file format
+    saved_file = save_file(response, filename, file_format)
+    yield response, history + [(message, response)], saved_file
+
+# Gradio interface
+demo = gr.ChatInterface(
+    fn=respond,
+    additional_inputs=[
+        gr.Textbox(value="", label="System message"),
+        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+        gr.Slider(
+            minimum=0.1,
+            maximum=1.0,
+            value=0.95,
+            step=0.05,
+            label="Top-P",
+        ),
+        gr.Textbox(value="output", label="Filename"),
         gr.Radio(["pdf", "docx", "txt"], label="File Format", value="pdf"),
     ],
     outputs=[
+        gr.Textbox(label="Generated Text"),
+        gr.State(value=[]),  # history
+        gr.File(label="Download File"),
     ],
     css=css,
+    theme="allenai/gradio-theme",
 )

+if __name__ == "__main__":
+    demo.launch()
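
The core of this change is the switch from InferenceClient.text_generation with a hand-built Mistral [INST] prompt to the OpenAI-compatible chat-completions endpoint of the Hugging Face Inference API, consumed as a token stream. Below is a minimal standalone sketch of that streaming pattern, assuming HF_TOKEN is set in the environment and the serverless endpoint serves the named Llama model; the stream_reply helper and the guard against None deltas are illustrative additions, not part of the committed app.py.

# Minimal sketch of the streaming pattern used by respond(), outside Gradio.
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key=os.getenv("HF_TOKEN"),
)

def stream_reply(prompt: str, system_message: str = "") -> str:
    """Stream a chat completion token by token and return the full text."""
    messages = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]
    response = ""
    for chunk in client.chat.completions.create(
        model="meta-llama/Meta-Llama-3.1-70B-Instruct",
        messages=messages,
        max_tokens=512,
        temperature=0.7,
        top_p=0.95,
        stream=True,
    ):
        token = chunk.choices[0].delta.content
        if token:  # the final chunk may carry a None delta
            print(token, end="", flush=True)
            response += token
    return response

if __name__ == "__main__":
    stream_reply("Write a one-paragraph summary of the water cycle.")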
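
The body of save_file is collapsed in this view (only the opening FPDF branch and the closing else are visible), since it is unchanged by the commit. The following is a hedged reconstruction of what such a helper typically looks like with fpdf and python-docx; the exact fonts, margins, and return value in the committed file may differ.

# Hypothetical reconstruction of the collapsed save_file() body.
from fpdf import FPDF
import docx

def save_file(content, filename, file_format):
    if file_format == "pdf":
        pdf = FPDF()
        pdf.add_page()
        pdf.set_font("Helvetica", size=12)
        pdf.multi_cell(0, 10, content)  # wrap long text across lines and pages
        path = f"{filename}.pdf"
        pdf.output(path)
    elif file_format == "docx":
        doc = docx.Document()
        doc.add_paragraph(content)
        path = f"{filename}.docx"
        doc.save(path)
    elif file_format == "txt":
        path = f"{filename}.txt"
        with open(path, "w", encoding="utf-8") as f:
            f.write(content)
    else:
        raise ValueError("Unsupported file format")
    return path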