| 55 | |
| 56 | |
class GeminiDecoder(DecoderBase):
    """Code generator backed by a Google Gemini model.

    ``codegen`` sends a single user message made of
    ``self.instruction_prefix`` followed by the prompt in a fenced Python
    block, and returns the text of each response candidate.
    """

    # Maximum number of candidates requested in one ``codegen`` call,
    # independent of ``batch_size``/``num_samples``. Overridable per subclass
    # or instance. NOTE(review): presumably mirrors the Gemini API's
    # ``candidate_count`` limit — confirm if the limit changes.
    max_candidates_per_request = 8

    def __init__(self, name: str, **kwargs):
        """Configure the Gemini client for model ``name``.

        The API key is read from the ``GOOGLE_API_KEY`` environment variable.
        All other keyword arguments are forwarded to ``DecoderBase``.
        """
        super().__init__(name, **kwargs)
        genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
        self.client = genai.GenerativeModel(name)

    def codegen(
        self, prompt: str, do_sample: bool = True, num_samples: int = 200
    ) -> List[str]:
        """Request one batch of completions for ``prompt``.

        Args:
            prompt: Code prompt. It is stripped and wrapped in a fenced
                Python block after ``self.instruction_prefix``.
            do_sample: When true, ``self.temperature`` must be positive.
                Either way, ``self.temperature`` is what gets sent.
            num_samples: Upper bound on the number of outputs wanted.

        Returns:
            Exactly ``batch_size`` strings, where ``batch_size`` is
            ``min(self.batch_size, num_samples, max_candidates_per_request)``.
            This may be fewer than ``num_samples``. A candidate with no
            content parts, or a candidate the API did not return, is
            represented by ``""``.

        Raises:
            AssertionError: If ``do_sample`` is true and the temperature
                is not positive.
        """
        if do_sample:
            assert self.temperature > 0, "Temperature must be positive for sampling"
        batch_size = min(
            self.batch_size, num_samples, self.max_candidates_per_request
        )
        message = self.instruction_prefix + f"\n```python\n{prompt.strip()}\n```"
        replies = make_auto_request(
            self.client,
            [{"role": "user", "content": message}],
            n=batch_size,
            temperature=self.temperature,
            max_new_tokens=self.max_new_tokens,
        )

        if len(replies.candidates) != batch_size:
            print(
                f"WARNING: Expected {batch_size} outputs but got {len(replies.candidates)}"
            )

        ret_texts = []
        # Slice so that surplus candidates cannot push the result past
        # batch_size (the padding below only ever adds, never trims).
        for candidate in replies.candidates[:batch_size]:
            parts = candidate.content.parts
            if parts:
                # A single candidate may be split across several parts;
                # keep all of its text instead of only the first part.
                ret_texts.append("".join(part.text for part in parts))
            else:
                # No content (e.g. the response was blocked): record a
                # placeholder and dump the safety ratings for diagnosis.
                print("Empty response!")
                ret_texts.append("")
                print(f"{candidate.safety_ratings = }")

        # Pad with empty strings if the API returned fewer candidates.
        return ret_texts + [""] * (batch_size - len(ret_texts))

    def is_direct_completion(self) -> bool:
        """Return False: ``codegen`` sends an instruction-wrapped chat message.

        The prompt is not continued verbatim.
        """
        return False