Akjava committed
Commit 8fd5823 · 1 Parent(s): d665e1b
Files changed (1)
  1. app.py +11 -6
app.py CHANGED
@@ -24,12 +24,17 @@ print(model_id,device,dtype)
 histories = []
 #model = None
 
+model = AutoModelForCausalLM.from_pretrained(
+    model_id, token=huggingface_token ,torch_dtype=dtype,device_map=device
+    )
+text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,torch_dtype=dtype,device_map=device) #pipeline has not to(device)
+
 
 @spaces.GPU(duration=120)
 def generate_text(messages):
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id, token=huggingface_token ,torch_dtype=dtype,device_map=device
-    )
+    # model = AutoModelForCausalLM.from_pretrained(
+    #     model_id, token=huggingface_token ,torch_dtype=dtype,device_map=device
+    # )
 
     text_generator = pipeline("text-generation", model=model, tokenizer=tokenizer,torch_dtype=dtype,device_map=device) #pipeline has not to(device)
     result = text_generator(messages, max_new_tokens=256, do_sample=True, temperature=0.7)
@@ -48,7 +53,7 @@ def generate_text(messages):
 
 
 def call_generate_text(message, history):
-    history.append({"role": "assistant", "content": message})
+    history.append({"role": "user", "content": message})
     print(message)
     print(history)
 
@@ -63,7 +68,7 @@ def call_generate_text(message, history):
 
     return ""
 
-demo = gr.ChatInterface(call_generate_text)
+demo = gr.ChatInterface(call_generate_text,type="messages")
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
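
Taken together, this commit moves AutoModelForCausalLM.from_pretrained and the pipeline construction from inside the @spaces.GPU(duration=120) function to module scope, appends the incoming turn to history with role "user" instead of "assistant", switches gr.ChatInterface to type="messages", and launches with share=True. Below is a minimal, self-contained sketch of that pattern; the stand-in model id, the HF_TOKEN lookup, and the simplified function bodies are illustrative assumptions, since the undiffed parts of app.py are not shown in this commit.

```python
# Minimal sketch of the post-commit pattern on a ZeroGPU Space.
# Assumptions: the stand-in model_id, the HF_TOKEN lookup, and the simplified
# return values are illustrative; the Space's actual setup above the diffed
# region is not part of this commit.
import os

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "Qwen/Qwen2.5-0.5B-Instruct"    # stand-in, not the Space's actual model
huggingface_token = os.getenv("HF_TOKEN")  # assumption: token read from Space secrets
device = "cuda"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_id, token=huggingface_token)

# Post-commit: model and pipeline are built once at import time,
# outside the GPU-decorated function.
model = AutoModelForCausalLM.from_pretrained(
    model_id, token=huggingface_token, torch_dtype=dtype, device_map=device
)
text_generator = pipeline(
    "text-generation", model=model, tokenizer=tokenizer,
    torch_dtype=dtype, device_map=device,
)


@spaces.GPU(duration=120)
def generate_text(messages):
    # Only inference runs inside the ZeroGPU allocation.
    result = text_generator(messages, max_new_tokens=256, do_sample=True, temperature=0.7)
    # Illustrative: with chat-format input the pipeline returns the full
    # conversation; take the newly generated assistant turn.
    return result[0]["generated_text"][-1]["content"]


def call_generate_text(message, history):
    # With type="messages", history is a list of {"role", "content"} dicts,
    # so the incoming user turn is appended with role "user".
    history.append({"role": "user", "content": message})
    return generate_text(history)


demo = gr.ChatInterface(call_generate_text, type="messages")

if __name__ == "__main__":
    demo.launch(share=True)
```

Building the model at import time presumably keeps the heavyweight from_pretrained call outside the 120-second GPU window, so the allocation is spent on generation rather than on loading weights.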