generate.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. import os
  2. import json
  3. from argparse import ArgumentParser
  4. from typing import List
  5. import torch
  6. import torch.distributed as dist
  7. from transformers import AutoTokenizer
  8. from safetensors.torch import load_model
  9. from model import Transformer, ModelArgs
  def sample(logits, temperature: float = 1.0):
      """Pick one index along the last dimension of ``logits`` at random.

      The logits are divided by ``temperature`` (floored at 1e-5 so that a
      zero temperature does not divide by zero) and passed through a softmax
      over the last dimension. Every probability is then divided by its own
      Exponential(1) draw and the position of the largest ratio is returned,
      which selects each index with its softmax probability.

      Args:
          logits: Tensor of unnormalised scores; sampling happens over its
              last dimension.
          temperature: Divisor applied to the logits before the softmax.

      Returns:
          Tensor of selected indices, with the last dimension of ``logits``
          reduced away.
      """
      safe_temperature = max(temperature, 1e-5)
      distribution = torch.softmax(logits / safe_temperature, dim=-1)
      # One independent Exponential(1) draw per probability entry.
      noise = torch.empty_like(distribution).exponential_(1)
      return torch.argmax(distribution / noise, dim=-1)
  @torch.inference_mode()
  def generate(
      model: Transformer,
      prompt_tokens: List[List[int]],
      max_new_tokens: int,
      eos_id: int,
      temperature: float = 1.0
  ) -> List[List[int]]:
      """Generate a completion for every prompt in a batch.

      All prompts are packed into one ``(batch, total_len)`` buffer on the
      CUDA device, with ``-1`` marking slots that are not part of a prompt.
      Decoding starts at the length of the shortest prompt. At each position,
      sequences whose prompt still covers that position keep their prompt
      token; the others receive a newly chosen token.

      Args:
          model: Model exposing ``max_seq_len`` and
              ``forward(tokens, start_pos)``. The result of ``forward`` is
              reduced over its last dimension to pick the next token
              (presumably ``(batch, vocab)`` logits -- confirm in ``model``).
          prompt_tokens: One list of token ids per sequence.
          max_new_tokens: Upper bound on the number of tokens returned per
              sequence.
          eos_id: Token id that ends a sequence; it is not included in the
              returned completion.
          temperature: Values greater than 0 select sampling via ``sample``;
              otherwise the arg-max token is taken.

      Returns:
          One list of generated token ids per prompt, in input order. An empty
          batch yields an empty list.

      Raises:
          AssertionError: If the longest prompt exceeds ``model.max_seq_len``.
      """
      # Nothing to decode; also avoids ``max()`` failing on an empty sequence.
      if not prompt_tokens:
          return []

      prompt_lens = [len(t) for t in prompt_tokens]
      longest_prompt = max(prompt_lens)
      assert longest_prompt <= model.max_seq_len, (
          f"Prompt length {longest_prompt} exceeds model maximum sequence length "
          f"(max_seq_len={model.max_seq_len})"
      )
      total_len = min(model.max_seq_len, max_new_tokens + longest_prompt)

      # -1 is the "not a prompt token" sentinel used to derive prompt_mask below.
      tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
      for i, t in enumerate(prompt_tokens):
          tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")

      prev_pos = 0
      finished = torch.zeros(len(prompt_tokens), dtype=torch.bool, device="cuda")
      prompt_mask = tokens != -1
      for cur_pos in range(min(prompt_lens), total_len):
          # Only the tokens added since the previous step are passed in, together
          # with their starting offset.
          logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
          if temperature > 0:
              next_token = sample(logits, temperature)
          else:
              next_token = logits.argmax(dim=-1)
          # Longer prompts are still being consumed here: keep their own token.
          next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
          tokens[:, cur_pos] = next_token
          # An EOS only counts when it was generated, not when it came from a prompt.
          finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
          prev_pos = cur_pos
          if finished.all():
              break

      completion_tokens = []
      for i, toks in enumerate(tokens.tolist()):
          start = prompt_lens[i]
          toks = toks[start:start + max_new_tokens]
          if eos_id in toks:
              toks = toks[:toks.index(eos_id)]
          completion_tokens.append(toks)
      return completion_tokens
  def main(
      ckpt_path: str,
      config: str,
      input_file: str = "",
      interactive: bool = True,
      max_new_tokens: int = 100,
      temperature: float = 1.0,
  ) -> None:
      """Build the model, load its weights and run chat or batch generation.

      ``WORLD_SIZE``, ``RANK`` and ``LOCAL_RANK`` are read from the environment.
      When ``WORLD_SIZE`` is greater than 1 an NCCL process group is created
      and destroyed again at the end. On every rank other than 0 the
      module-level ``print`` is rebound to a no-op, so only rank 0 writes
      output.

      Args:
          ckpt_path: Directory passed to the tokenizer loader and expected to
              contain ``model{rank}-mp{world_size}.safetensors``.
          config: Path to a JSON file whose keys are forwarded to ``ModelArgs``.
          input_file: Text file with one prompt per line; read only when
              ``interactive`` is false.
          interactive: If true, run a prompt loop that keeps the conversation
              history. ``/exit`` leaves the loop and ``/clear`` resets the
              history.
          max_new_tokens: Generation limit forwarded to ``generate``.
          temperature: Sampling temperature forwarded to ``generate``.
      """
      global print

      env = os.environ
      world_size = int(env.get("WORLD_SIZE", "1"))
      rank = int(env.get("RANK", "0"))
      local_rank = int(env.get("LOCAL_RANK", "0"))
      distributed = world_size > 1
      if distributed:
          dist.init_process_group("nccl")
      if rank != 0:
          # Silence every rank except the first.
          print = lambda *_, **__: None

      torch.cuda.set_device(local_rank)
      torch.set_default_dtype(torch.bfloat16)
      torch.set_num_threads(8)
      torch.manual_seed(965)

      with open(config) as config_file:
          args = ModelArgs(**json.load(config_file))
      print(args)
      with torch.device("cuda"):
          model = Transformer(args)
      tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

      # A two-token generation runs before the checkpoint is loaded and its
      # result is discarded (presumably a warm-up pass -- confirm intent).
      warmup_ids = generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)
      tokenizer.decode(warmup_ids[0])
      weights_file = os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")
      load_model(model, weights_file)

      if not interactive:
          with open(input_file) as prompt_file:
              prompts = [line.strip() for line in prompt_file]
          assert len(prompts) <= args.max_batch_size
          prompt_tokens = [
              tokenizer.apply_chat_template(
                  [{"role": "user", "content": text}], add_generation_prompt=True
              )
              for text in prompts
          ]
          completion_tokens = generate(
              model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature
          )
          completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
          for text, reply in zip(prompts, completions):
              print("Prompt:", text)
              print("Completion:", reply)
              print()
      else:
          history = []
          while True:
              if world_size == 1:
                  prompt = input(">>> ")
              else:
                  # Rank 0 reads the prompt; the broadcast hands it to the others.
                  shared = [input(">>> ") if rank == 0 else None]
                  dist.broadcast_object_list(shared, 0)
                  prompt = shared[0]

              if prompt == "/exit":
                  break
              if prompt == "/clear":
                  history.clear()
                  continue

              history.append({"role": "user", "content": prompt})
              prompt_tokens = tokenizer.apply_chat_template(history, add_generation_prompt=True)
              completion_tokens = generate(
                  model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature
              )
              completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
              print(completion)
              history.append({"role": "assistant", "content": completion})

      if distributed:
          dist.destroy_process_group()
  if __name__ == "__main__":
      # Command-line entry point: parse the options and hand them to main().
      # The CLI defaults for --max-new-tokens (200) and --temperature (0.2)
      # differ from main()'s own defaults and take precedence here.
      parser = ArgumentParser()
      parser.add_argument("--ckpt-path", type=str, required=True,
                          help="directory holding the tokenizer and model shards")
      parser.add_argument("--config", type=str, required=True,
                          help="path to the JSON model configuration")
      parser.add_argument("--input-file", type=str, default="",
                          help="text file with one prompt per line (batch mode)")
      parser.add_argument("--interactive", action="store_true",
                          help="run an interactive chat loop")
      parser.add_argument("--max-new-tokens", type=int, default=200,
                          help="maximum number of tokens to generate per prompt")
      parser.add_argument("--temperature", type=float, default=0.2,
                          help="sampling temperature")
      args = parser.parse_args()
      # Reported through argparse rather than ``assert`` so that the check is
      # still enforced under ``python -O`` and the user gets a usage message.
      if not (args.input_file or args.interactive):
          parser.error("one of --input-file or --interactive is required")
      main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)