# generate.py (7.4 KB)
  1. import os
  2. import json
  3. from argparse import ArgumentParser
  4. from typing import List
  5. import torch
  6. import torch.distributed as dist
  7. from transformers import AutoTokenizer
  8. from safetensors.torch import load_model
  9. from model import Transformer, ModelArgs
  10. def sample(logits, temperature: float = 1.0):
  11. """
  12. Samples a token from the logits using temperature scaling.
  13. Args:
  14. logits (torch.Tensor): The logits tensor for token predictions.
  15. temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
  16. Returns:
  17. torch.Tensor: The sampled token.
  18. """
  19. logits = logits / max(temperature, 1e-5)
  20. probs = torch.softmax(logits, dim=-1)
  21. return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
  22. @torch.inference_mode()
  23. def generate(
  24. model: Transformer,
  25. prompt_tokens: List[List[int]],
  26. max_new_tokens: int,
  27. eos_id: int,
  28. temperature: float = 1.0
  29. ) -> List[List[int]]:
  30. """
  31. Generates new tokens based on the given prompt tokens using the specified model.
  32. Args:
  33. model (Transformer): The transformer model used for token generation.
  34. prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
  35. max_new_tokens (int): The maximum number of new tokens to generate.
  36. eos_id (int): The end-of-sequence token ID.
  37. temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
  38. Returns:
  39. List[List[int]]: A list of lists containing the generated tokens for each sequence.
  40. """
  41. prompt_lens = [len(t) for t in prompt_tokens]
  42. assert max(prompt_lens) <= model.max_seq_len
  43. total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
  44. tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
  45. for i, t in enumerate(prompt_tokens):
  46. tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
  47. prev_pos = 0
  48. finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
  49. prompt_mask = tokens != -1
  50. for cur_pos in range(min(prompt_lens), total_len):
  51. logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
  52. if temperature > 0:
  53. next_token = sample(logits, temperature)
  54. else:
  55. next_token = logits.argmax(dim=-1)
  56. next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
  57. tokens[:, cur_pos] = next_token
  58. finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
  59. prev_pos = cur_pos
  60. if finished.all():
  61. break
  62. completion_tokens = []
  63. for i, toks in enumerate(tokens.tolist()):
  64. toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
  65. if eos_id in toks:
  66. toks = toks[:toks.index(eos_id)]
  67. completion_tokens.append(toks)
  68. return completion_tokens
  69. def main(
  70. ckpt_path: str,
  71. config: str,
  72. input_file: str = "",
  73. interactive: bool = True,
  74. max_new_tokens: int = 100,
  75. temperature: float = 1.0,
  76. ) -> None:
  77. """
  78. Main function to load the model and perform interactive or batch text generation.
  79. Args:
  80. ckpt_path (str): Path to the model checkpoint directory.
  81. config (str): Path to the model configuration file.
  82. input_file (str, optional): Path to a file containing input prompts. Defaults to "".
  83. interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
  84. max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
  85. temperature (float, optional): Temperature for sampling. Defaults to 1.0.
  86. """
  87. world_size = int(os.getenv("WORLD_SIZE", "1"))
  88. rank = int(os.getenv("RANK", "0"))
  89. local_rank = int(os.getenv("LOCAL_RANK", "0"))
  90. if world_size > 1:
  91. dist.init_process_group("nccl")
  92. global print
  93. if rank != 0:
  94. print = lambda *_, **__: None
  95. torch.cuda.set_device(local_rank)
  96. torch.set_default_dtype(torch.bfloat16)
  97. torch.set_num_threads(8)
  98. torch.manual_seed(965)
  99. with open(config) as f:
  100. args = ModelArgs(**json.load(f))
  101. print(args)
  102. with torch.device("cuda"):
  103. model = Transformer(args)
  104. tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
  105. tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
  106. load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
  107. if interactive:
  108. messages = []
  109. while True:
  110. if world_size == 1:
  111. prompt = input(">>> ")
  112. elif rank == 0:
  113. prompt = input(">>> ")
  114. objects = [prompt]
  115. dist.broadcast_object_list(objects, 0)
  116. else:
  117. objects = [None]
  118. dist.broadcast_object_list(objects, 0)
  119. prompt = objects[0]
  120. if prompt == "/exit":
  121. break
  122. elif prompt == "/clear":
  123. messages.clear()
  124. continue
  125. messages.append({"role": "user", "content": prompt})
  126. prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
  127. completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
  128. completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
  129. print(completion)
  130. messages.append({"role": "assistant", "content": completion})
  131. else:
  132. with open(input_file) as f:
  133. prompts = [line.strip() for line in f.readlines()]
  134. assert len(prompts) <= args.max_batch_size
  135. prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
  136. completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
  137. completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
  138. for prompt, completion in zip(prompts, completions):
  139. print("Prompt:", prompt)
  140. print("Completion:", completion)
  141. print()
  142. if world_size > 1:
  143. dist.destroy_process_group()
  144. if __name__ == "__main__":
  145. """
  146. Command-line interface for distributed text generation.
  147. Arguments:
  148. --ckpt-path (str): Path to the model checkpoint directory.
  149. --config (str): Path to the model configuration file.
  150. --input-file (str, optional): File containing prompts for batch processing.
  151. --interactive (bool, optional): Enable interactive mode for generating text.
  152. --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
  153. --temperature (float, optional): Temperature for sampling. Defaults to 0.2.
  154. Raises:
  155. AssertionError: If neither input-file nor interactive mode is specified.
  156. """
  157. parser = ArgumentParser()
  158. parser.add_argument("--ckpt-path", type=str, required=True)
  159. parser.add_argument("--config", type=str, required=True)
  160. parser.add_argument("--input-file", type=str, default="")
  161. parser.add_argument("--interactive", action="store_true")
  162. parser.add_argument("--max-new-tokens", type=int, default=200)
  163. parser.add_argument("--temperature", type=float, default=0.2)
  164. args = parser.parse_args()
  165. assert args.input_file or args.interactive
  166. main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)